diff --git a/DiffSynth-Studio/diffsynth/__init__.py b/DiffSynth-Studio/diffsynth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb67a43fa4e5791ab58e7e40260bc3df8b6bc7cc --- /dev/null +++ b/DiffSynth-Studio/diffsynth/__init__.py @@ -0,0 +1 @@ +from .core import * diff --git a/DiffSynth-Studio/diffsynth/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f16c0b6c948ece347efaff0cd70c0aa049a8a10b Binary files /dev/null and b/DiffSynth-Studio/diffsynth/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/configs/__init__.py b/DiffSynth-Studio/diffsynth/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..144a822978b12a8297e341760a74791fb12802c0 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/configs/__init__.py @@ -0,0 +1,2 @@ +from .model_configs import MODEL_CONFIGS +from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS diff --git a/DiffSynth-Studio/diffsynth/configs/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/configs/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bc4adb6deabe0145943dafa35b94d150ac4e27a Binary files /dev/null and b/DiffSynth-Studio/diffsynth/configs/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/configs/__pycache__/model_configs.cpython-39.pyc b/DiffSynth-Studio/diffsynth/configs/__pycache__/model_configs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6133ae70f023a07ca0072c6ce77652f6af5fbf14 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/configs/__pycache__/model_configs.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/configs/__pycache__/vram_management_module_maps.cpython-39.pyc b/DiffSynth-Studio/diffsynth/configs/__pycache__/vram_management_module_maps.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5fdf94f7855e65fbbb38cb001f6a875ff95c1c8 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/configs/__pycache__/vram_management_module_maps.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/configs/model_configs.py b/DiffSynth-Studio/diffsynth/configs/model_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..c93f5e9a311096bec70d5823881dead29f772ab8 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/configs/model_configs.py @@ -0,0 +1,594 @@ +qwen_image_series = [ + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors") + "model_hash": "0319a1cb19835fb510907dd3367c95ff", + "model_name": "qwen_image_dit", + "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT", + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "8004730443f55db63092006dd9f7110e", + "model_name": "qwen_image_text_encoder", + "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "ed4ea5824d55ec3107b09815e318123a", + "model_name": "qwen_image_vae", + "model_class": 
"diffsynth.models.qwen_image_vae.QwenImageVAE", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors") + "model_hash": "073bce9cf969e317e5662cd570c3e79c", + "model_name": "qwen_image_blockwise_controlnet", + "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors") + "model_hash": "a9e54e480a628f0b956a688a81c33bab", + "model_name": "qwen_image_blockwise_controlnet", + "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet", + "extra_kwargs": {"additional_in_dim": 4}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors") + "model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8", + "model_name": "siglip2_image_encoder", + "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors") + "model_hash": "5722b5c873720009de96422993b15682", + "model_name": "dinov3_image_encoder", + "model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder", + }, + { + # Example: + "model_hash": "a166c33455cdbd89c0888a3645ca5c0f", + "model_name": "qwen_image_image2lora_coarse", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + }, + { + # Example: + "model_hash": "a5476e691767a4da6d3a6634a10f7408", + "model_name": "qwen_image_image2lora_fine", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + "extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64} + }, + { + # Example: + "model_hash": "0aad514690602ecaff932c701cb4b0bb", + "model_name": "qwen_image_image2lora_style", + "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 64, "use_residual": False} + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "8dc8cda05de16c73afa755e2c1ce2839", + "model_name": "qwen_image_dit", + "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT", + "extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True} + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "44b39ddc499e027cfb24f7878d7416b9", + "model_name": "qwen_image_vae", + "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE", + "extra_kwargs": {"image_channels": 4} + }, +] + +wan_series = [ + { + # Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors") + "model_hash": "5ec04e02b42d2580483ad69f4e76346a", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", 
origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth") + "model_hash": "9c8818c2cbea55eca56c7b447df170da", + "model_name": "wan_video_text_encoder", + "model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth") + "model_hash": "ccc42284ea13e1ad04693284c7a09be6", + "model_name": "wan_video_vae", + "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter", + }, + { + # Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors") + "model_hash": "8b27900f680d7251ce44e2dc8ae1ffef", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel", + }, + { + # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "5f90e66a0672219f12d9a626c8c21f61", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers" + }, + { + # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "5f90e66a0672219f12d9a626c8c21f61", + "model_name": "wan_video_vap", + "model_class": "diffsynth.models.wan_video_mot.MotWanModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") + "model_hash": "5941c53e207d62f20f9025686193c40b", + "model_name": "wan_video_image_encoder", + "model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter" + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors") + "model_hash": "dbd5ec76bbf977983f972c151d545389", + "model_name": "wan_video_motion_controller", + "model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "9269f8db9040a9d860eaca435be61814", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 
'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "349723183fc063b2bfc10bb2835cf677", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "6d6ccde6845b95ad9114ab993d917893", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "efa44cddf936c70abd0ea28b6cbe946c", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "6bfcfb3b342cb286ce886889d519a77e", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "70ddad9d3a133785da5ea371aae09504", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "b61c605c2adbd23124d152ed28e049ae", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 
'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "26bde73488a92e64cc20b0a7485b9e5b", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06} + }, + { + # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "a61453409b67cd3246cf0c3bebad47ba", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "a61453409b67cd3246cf0c3bebad47ba", + "model_name": "wan_video_vace", + "model_class": "diffsynth.models.wan_video_vace.VaceWanModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "7a513e1f257a861512b1afd387a8ecd9", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "7a513e1f257a861512b1afd387a8ecd9", + "model_name": "wan_video_vace", + "model_class": "diffsynth.models.wan_video_vace.VaceWanModel", + "extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", 
origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "31fa352acb8a1b1d33cd8764273d80a2", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter" + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "31fa352acb8a1b1d33cd8764273d80a2", + "model_name": "wan_video_animate_adapter", + "model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter" + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors") + "model_hash": "47dbeab5e560db3180adf51dc0232fb1", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False} + }, + { + # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors") + "model_hash": "2267d489f0ceb9f21836532952852ee5", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False}, + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors") + "model_hash": "5b013604280dd715f8457c6ed6d6a626", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "966cffdcc52f9c46c391768b27637614", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel", + "extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors") + "model_hash": "1f5ab7703c6fc803fdded85ff040c316", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': False, 
'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True} + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth") + "model_hash": "e1de6c02cdac79f8b739f4d3698cd216", + "model_name": "wan_video_vae", + "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors") + "model_hash": "06be60f3a4526586d8431cd038a71486", + "model_name": "wans2v_audio_encoder", + "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter", + }, +] + +flux_series = [ + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors") + "model_hash": "a29710fea6dddb0314663ee823598e50", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Supported due to historical reasons. + "model_hash": "605c56eab23e9e2af863ad8f0813a25d", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors") + "model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78", + "model_name": "flux_text_encoder_clip", + "model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors") + "model_hash": "22540b49eaedbc2f2784b2091a234c7c", + "model_name": "flux_text_encoder_t5", + "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors") + "model_hash": "21ea55f476dfc4fd135587abb59dfe5d", + "model_name": "flux_vae_encoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors") + "model_hash": "21ea55f476dfc4fd135587abb59dfe5d", + "model_name": "flux_vae_decoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors") + "model_hash": "d02f41c13549fa5093d3521f62a5570a", + "model_name": 
"flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "extra_kwargs": {'input_dim': 196, 'num_blocks': 8}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors") + "model_hash": "0629116fce1472503a66992f96f3eb1a", + "model_name": "flux_value_controller", + "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder", + }, + { + # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "52357cb26250681367488a8954c271e8", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}, + }, + { + # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "78d18b9101345ff695f312e7e62538c0", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}, + }, + { + # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "b001c89139b5f053c715fe772362dd2a", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_single_blocks": 0}, + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin") + "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481", + "model_name": "infiniteyou_image_projector", + "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter", + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors") + "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors") + "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab", + "model_name": "flux_lora_encoder", + "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors") + "model_hash": "30143afb2dea73d1ac580e0787628f8c", + "model_name": "flux_lora_patcher", + "model_class": 
"diffsynth.models.flux_lora_patcher.FluxLoraPatcher", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors") + "model_hash": "2bd19e845116e4f875a0a048e27fc219", + "model_name": "nexus_gen_llm", + "model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin") + "model_hash": "63c969fd37cce769a90aa781fbff5f81", + "model_name": "nexus_gen_editing_adapter", + "model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin") + "model_hash": "63c969fd37cce769a90aa781fbff5f81", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin") + "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d", + "model_name": "nexus_gen_generation_adapter", + "model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin") + "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin") + "model_hash": "4daaa66cc656a8fe369908693dad0a35", + "model_name": "flux_ipadapter", + "model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter", + }, + { + # Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors") + "model_hash": "04d8c1e20a1f1b25f7434f111992a33f", + "model_name": "siglip_vision_model", + "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter", + }, + { + # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50", + "model_name": "step1x_connector", + "model_class": "diffsynth.models.step1x_connector.Qwen2Connector", + "state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter", + }, + { + # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + "extra_kwargs": 
{"disable_guidance_embedder": True}, + }, + { + # Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors") + "model_hash": "3394f306c4cbf04334b712bf5aaed95f", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, +] + +flux2_series = [ + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "28fca3d8e5bf2a2d1271748a773f6757", + "model_name": "flux2_text_encoder", + "model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors") + "model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "c54288e3ee12ca215898840682337b95", + "model_name": "flux2_vae", + "model_class": "diffsynth.models.flux2_vae.Flux2VAE", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors") + "model_hash": "3bde7b817fec8143028b6825a63180df", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20} + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "9195f3ea256fcd0ae6d929c203470754", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + "extra_kwargs": {"model_size": "8B"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors") + "model_hash": "39c6fc48f07bebecedbbaa971ff466c8", + "model_name": "flux2_dit", + "model_class": "diffsynth.models.flux2_dit.Flux2DiT", + "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24} + }, +] + +z_image_series = [ + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors") + "model_hash": "fc3a8a1247fe185ce116ccbe0e426c28", + "model_name": "z_image_dit", + "model_class": "diffsynth.models.z_image_dit.ZImageDiT", + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors") + "model_hash": "0f050f62a88876fea6eae0a18dac5a2e", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors") + "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3", + "model_name": "flux_vae_encoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder", + "state_dict_converter": 
"diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers", + "extra_kwargs": {"use_conv_attention": False}, + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors") + "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3", + "model_name": "flux_vae_decoder", + "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers", + "extra_kwargs": {"use_conv_attention": False}, + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors") + "model_hash": "aa3563718e5c3ecde3dfbb020ca61180", + "model_name": "z_image_dit", + "model_class": "diffsynth.models.z_image_dit.ZImageDiT", + "extra_kwargs": {"siglip_feat_dim": 1152}, + }, + { + # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors") + "model_hash": "89d48e420f45cff95115a9f3e698d44a", + "model_name": "siglip_vision_model_428m", + "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M", + }, + { + # Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors") + "model_hash": "1677708d40029ab380a95f6c731a57d7", + "model_name": "z_image_controlnet", + "model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet", + }, + { + # Example: ??? + "model_hash": "9510cb8cd1dd34ee0e4f111c24905510", + "model_name": "z_image_image2lora_style", + "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 128}, + }, +] + +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series diff --git a/DiffSynth-Studio/diffsynth/configs/vram_management_module_maps.py b/DiffSynth-Studio/diffsynth/configs/vram_management_module_maps.py new file mode 100644 index 0000000000000000000000000000000000000000..a1813fb4b1d040c1ac7c5cf79bd4ea5136a76e07 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/configs/vram_management_module_maps.py @@ -0,0 +1,213 @@ +flux_general_vram_config = { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule", +} + +VRAM_MANAGEMENT_MODULE_MAPS = { + "diffsynth.models.qwen_image_dit.QwenImageDiT": { + "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": 
"diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.qwen_image_vae.QwenImageVAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": { + "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": { + "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": { + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": { + "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_dit_s2v.WanS2VModel": { + "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule", + 
"diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_dit.WanModel": { + "diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule", + "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_image_encoder.WanImageEncoder": { + "diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_mot.MotWanModel": { + "diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.wan_video_text_encoder.WanTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_vace.VaceWanModel": { + "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_vae.WanVideoVAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + 
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wan_video_vae.WanVideoVAE38": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.wav2vec.WanS2VAudioEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux_dit.FluxDiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config, + "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config, + "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config, + "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config, + "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config, + "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config, + "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config, + "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config, + "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config, + "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": { + "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": 
"diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux2_dit.Flux2DiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux2_vae.Flux2VAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_dit.ZImageDiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_controlnet.ZImageControlNet": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, + "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": { + "transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + }, +} diff --git a/DiffSynth-Studio/diffsynth/core/__init__.py b/DiffSynth-Studio/diffsynth/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0a6c8774ba11b6e2dd9d54775d50431acdaaee --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/__init__.py @@ -0,0 +1,6 @@ +from .attention import * +from .data import * +from .gradient import * +from .loader import * +from .vram import * +from .device import * diff --git a/DiffSynth-Studio/diffsynth/core/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1579eb6554a8ff048d693d8b2b6dfe79bbee3484 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/attention/__init__.py b/DiffSynth-Studio/diffsynth/core/attention/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..45cf8a4382397aa4ff6558f37191c726d821b95a --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/attention/__init__.py @@ -0,0 +1 @@ +from .attention import attention_forward diff --git a/DiffSynth-Studio/diffsynth/core/attention/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/attention/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd36416684b16f13506c5b59311dae7844829662 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/attention/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/attention/__pycache__/attention.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/attention/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b335c7235bb5eba2890c196e2001b1fd55ddf8a8 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/attention/__pycache__/attention.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/attention/attention.py b/DiffSynth-Studio/diffsynth/core/attention/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..15b55a43aa0ee248245fd6652f731f966091b6f7 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/attention/attention.py @@ -0,0 +1,121 @@ +import torch, os +from einops import rearrange + + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + +try: + import xformers.ops as xops + XFORMERS_AVAILABLE = True +except ModuleNotFoundError: + XFORMERS_AVAILABLE = False + + +def initialize_attention_priority(): + if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None: + return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower() + elif FLASH_ATTN_3_AVAILABLE: + return "flash_attention_3" + elif FLASH_ATTN_2_AVAILABLE: + return "flash_attention_2" + elif SAGE_ATTN_AVAILABLE: + return "sage_attention" + elif XFORMERS_AVAILABLE: + return "xformers" + else: + return "torch" + + +ATTENTION_IMPLEMENTATION = initialize_attention_priority() + + +def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None): + dims = {} if dims is None else dims + if q_pattern != required_in_pattern: + q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims) + if k_pattern != required_in_pattern: + k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims) + if v_pattern != required_in_pattern: + v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims) + return q, k, v + + +def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None): + dims = {} if dims is None else dims + if out_pattern != required_out_pattern: + out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims) + return out + + +def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None): + required_in_pattern, required_out_pattern= "b n s d", "b n s d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern,
required_in_pattern, dims) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b s n d", "b s n d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale) + if isinstance(out, tuple): + out = out[0] + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b s n d", "b s n d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b n s d", "b n s d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = sageattn(q, k, v, sm_scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None): + required_in_pattern, required_out_pattern= "b s n d", "b s n d" + q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) + out = xops.memory_efficient_attention(q, k, v, scale=scale) + out = rearrange_out(out, out_pattern, required_out_pattern, dims) + return out + + +def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False): + if compatibility_mode or (attn_mask is not None): + return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale) + else: + if ATTENTION_IMPLEMENTATION == "flash_attention_3": + return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + elif ATTENTION_IMPLEMENTATION == "flash_attention_2": + return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + elif ATTENTION_IMPLEMENTATION == "sage_attention": + return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + elif ATTENTION_IMPLEMENTATION == "xformers": + return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) + else: + return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) diff --git a/DiffSynth-Studio/diffsynth/core/data/__init__.py b/DiffSynth-Studio/diffsynth/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d494a277d81eeb2a9575155eb983d8bc3879590a --- 
/dev/null +++ b/DiffSynth-Studio/diffsynth/core/data/__init__.py @@ -0,0 +1 @@ +from .unified_dataset import UnifiedDataset diff --git a/DiffSynth-Studio/diffsynth/core/data/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27e8f964fca035463e476d67dde6043283eb0a54 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/data/__pycache__/operators.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/data/__pycache__/operators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..747ef28b585295cc7f0bca1b33cccf9779d90ea0 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/data/__pycache__/operators.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/data/__pycache__/unified_dataset.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/data/__pycache__/unified_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..305e51f7888c7a2c3cfe62c1eae1bfb67731139f Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/data/__pycache__/unified_dataset.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/data/operators.py b/DiffSynth-Studio/diffsynth/core/data/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..007c53189afe66a53aceb5070693bd01ef4086f7 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/data/operators.py @@ -0,0 +1,238 @@ +import torch, torchvision, imageio, os +import imageio.v3 as iio +from PIL import Image +import numpy as np + +class DataProcessingPipeline: + def __init__(self, operators=None): + self.operators: list[DataProcessingOperator] = [] if operators is None else operators + + def __call__(self, data): + for operator in self.operators: + data = operator(data) + return data + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline(self.operators + pipe.operators) + + +class DataProcessingOperator: + def __call__(self, data): + raise NotImplementedError("DataProcessingOperator cannot be called directly.") + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline([self]).__rshift__(pipe) + + +class DataProcessingOperatorRaw(DataProcessingOperator): + def __call__(self, data): + return data + + +class ToInt(DataProcessingOperator): + def __call__(self, data): + return int(data) + + +class ToFloat(DataProcessingOperator): + def __call__(self, data): + return float(data) + + +class ToStr(DataProcessingOperator): + def __init__(self, none_value=""): + self.none_value = none_value + + def __call__(self, data): + if data is None: data = self.none_value + return str(data) + + +class LoadImage(DataProcessingOperator): + def __init__(self, convert_RGB=True, convert_RGBA=False): + self.convert_RGB = convert_RGB + self.convert_RGBA = convert_RGBA + + def __call__(self, data: str): + image = Image.open(data) + if self.convert_RGB: image = image.convert("RGB") + if self.convert_RGBA: image = image.convert("RGBA") + return image + + +class ImageCropAndResize(DataProcessingOperator): + def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1): + self.height = height + self.width = width + self.max_pixels = 
max_pixels + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + def get_height_width(self, image): + if self.height is None or self.width is None: + width, height = image.size + if width * height > self.max_pixels: + scale = (width * height / self.max_pixels) ** 0.5 + height, width = int(height / scale), int(width / scale) + height = height // self.height_division_factor * self.height_division_factor + width = width // self.width_division_factor * self.width_division_factor + else: + height, width = self.height, self.width + return height, width + + def __call__(self, data: Image.Image): + image = self.crop_and_resize(data, *self.get_height_width(data)) + return image + + +class ToList(DataProcessingOperator): + def __call__(self, data): + return [data] + + + +class LoadVideo(DataProcessingOperator): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # frame_processor is build in the video loader for high efficiency. + self.frame_processor = frame_processor + + def get_num_frames(self, reader): + num_frames = self.num_frames + if int(reader.count_frames()) < num_frames: + num_frames = int(reader.count_frames()) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def __call__(self, data: str): + reader = imageio.get_reader(data) + + add_id = 0 + + frame_indices = np.linspace( + add_id, + add_id+80, + self.num_frames + ).round().astype(int) + + + frames = [] + for idx in frame_indices: + frame = reader.get_data(idx) + frame = Image.fromarray(frame) + frame = self.frame_processor(frame) + frames.append(frame) + + + reader.close() + + + last = frames[-1] + for _ in range(4): + frames.append(last) + + return frames + + +class SequencialProcess(DataProcessingOperator): + def __init__(self, operator=lambda x: x): + self.operator = operator + + def __call__(self, data): + return [self.operator(i) for i in data] + + +class LoadGIF(DataProcessingOperator): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # frame_processor is build in the video loader for high efficiency. 
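+ # For example, default_video_operator (in unified_dataset.py) passes frame_processor=ImageCropAndResize(...),
+ # so each decoded frame is cropped and resized as it is read instead of in a second pass.
+ # num_frames is later trimmed in get_num_frames so that num_frames % time_division_factor == time_division_remainder
+ # (81 = 4*20 + 1 with the defaults), which presumably matches the temporal compression stride of the video VAE.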
+ self.frame_processor = frame_processor + + def get_num_frames(self, path): + num_frames = self.num_frames + images = iio.imread(path, mode="RGB") + if len(images) < num_frames: + num_frames = len(images) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def __call__(self, data: str): + num_frames = self.get_num_frames(data) + frames = [] + images = iio.imread(data, mode="RGB") + for img in images: + frame = Image.fromarray(img) + frame = self.frame_processor(frame) + frames.append(frame) + if len(frames) >= num_frames: + break + return frames + + +class RouteByExtensionName(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data: str): + file_ext_name = data.split(".")[-1].lower() + for ext_names, operator in self.operator_map: + if ext_names is None or file_ext_name in ext_names: + return operator(data) + raise ValueError(f"Unsupported file: {data}") + + +class RouteByType(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data): + for dtype, operator in self.operator_map: + if dtype is None or isinstance(data, dtype): + return operator(data) + raise ValueError(f"Unsupported data: {data}") + + +class LoadTorchPickle(DataProcessingOperator): + def __init__(self, map_location="cpu"): + self.map_location = map_location + + def __call__(self, data): + return torch.load(data, map_location=self.map_location, weights_only=False) + + +class ToAbsolutePath(DataProcessingOperator): + def __init__(self, base_path=""): + self.base_path = base_path + + def __call__(self, data): + return os.path.join(self.base_path, data) + + +class LoadAudio(DataProcessingOperator): + def __init__(self, sr=16000): + self.sr = sr + def __call__(self, data: str): + import librosa + input_audio, sample_rate = librosa.load(data, sr=self.sr) + return input_audio diff --git a/DiffSynth-Studio/diffsynth/core/data/unified_dataset.py b/DiffSynth-Studio/diffsynth/core/data/unified_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0897eb615b041d1797ef810150848fd609114b --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/data/unified_dataset.py @@ -0,0 +1,105 @@ +from .operators import * +import torch, json + +def save_video_tensor_as_mp4(video_frames, out_path, fps=8): + + + # (C,T,H,W) -> (T,H,W,C) + video_np = [] + for frame in video_frames: + + frame_np = np.array(frame) + video_np.append(frame_np) + + + video = np.stack(video_np, axis=0) + + imageio.mimwrite( + out_path, + video, + fps=fps, + codec="libx264", + quality=8, + ) + + +class UnifiedDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path=None, + repeat=1, + data_file_keys=tuple(), + main_data_operator=lambda x: x, + ): + self.base_path = base_path + self.repeat = repeat + self.data_file_keys = data_file_keys + self.main_data_operator = main_data_operator + self.data = [] + self.load_metadata() + + @staticmethod + def default_video_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + num_frames=81, time_division_factor=4, time_division_remainder=1, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[ + (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()), + 
(("gif",), LoadGIF( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + ])), + ]) + + + def load_metadata(self): + src_dir = os.path.join(self.base_path, "point_video") + tgt_dir = os.path.join(self.base_path, "videos/train") + + video_exts = (".mp4", ".avi", ".mov", ".mkv", ".webm") + + for fname in os.listdir(src_dir): + if not fname.lower().endswith(video_exts): + continue + + src_path = os.path.join(src_dir, fname) + tgt_path = os.path.join(tgt_dir, fname) + + if not os.path.exists(tgt_path) or os.path.getsize(tgt_path) == 0: + print(f"跳过无效文件:{tgt_path}") + continue + if not os.path.exists(src_path) or os.path.getsize(src_path) == 0: + print(f"跳过无效文件:{src_path}") + continue + + self.data.append({ + "src_video": src_path, + "tgt_video": tgt_path, + "prompt": "Ensure the consistency of the video" + }) + + print(f"Found {len(self.data)} video pairs") + + + + def __getitem__(self, data_id): + + try: + data = self.data[data_id % len(self.data)].copy() + for key in self.data_file_keys: + if key in data: + data[key] = self.main_data_operator(data[key]) + return data + except Exception: + return self.__getitem__(data_id + 1) + + def __len__(self): + return len(self.data) * self.repeat diff --git a/DiffSynth-Studio/diffsynth/core/data/unified_dataset_old.py b/DiffSynth-Studio/diffsynth/core/data/unified_dataset_old.py new file mode 100644 index 0000000000000000000000000000000000000000..4ae141e7a9b621d5f9be574c67b5e532041b4712 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/data/unified_dataset_old.py @@ -0,0 +1,139 @@ +from .operators import * +import torch, json, pandas + + +class UnifiedDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path=None, metadata_path=None, + repeat=1, + data_file_keys=tuple(), + main_data_operator=lambda x: x, + max_data_items=None, + ): + self.base_path = base_path + self.metadata_path = metadata_path + self.repeat = repeat + self.data_file_keys = data_file_keys + self.main_data_operator = main_data_operator + self.cached_data_operator = LoadTorchPickle() + self.max_data_items = max_data_items + self.data = [] + self.cached_data = [] + self.load_from_cache = metadata_path is None + self.load_metadata(metadata_path) + + @staticmethod + def default_image_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)), + (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))), + ]) + + @staticmethod + def default_video_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + num_frames=81, time_division_factor=4, time_division_remainder=1, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[ + (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, 
height_division_factor, width_division_factor) >> ToList()), + (("gif",), LoadGIF( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + ])), + ]) + + def search_for_cached_data_files(self, path): + for file_name in os.listdir(path): + subpath = os.path.join(path, file_name) + if os.path.isdir(subpath): + self.search_for_cached_data_files(subpath) + elif subpath.endswith(".pth"): + self.cached_data.append(subpath) + + def load_metadata(self, metadata_path): + + if metadata_path == "folder_pair": + print("Loading paired videos directly from folders") + + src_dir = os.path.join(self.base_path, "src") + tgt_dir = os.path.join(self.base_path, "tgt") + + video_exts = (".mp4", ".avi", ".mov", ".mkv", ".webm") + self.data = [] + + for fname in os.listdir(src_dir): + if not fname.lower().endswith(video_exts): + continue + + src_path = os.path.join(src_dir, fname) + tgt_path = os.path.join(tgt_dir, fname) + + if not os.path.exists(tgt_path): + continue + + self.data.append({ + "src_video": src_path, + "tgt_video": tgt_path, + }) + + print(f"Found {len(self.data)} video pairs") + + + elif metadata_path is None: + print("No metadata_path. Searching for cached data files.") + self.search_for_cached_data_files(self.base_path) + print(f"{len(self.cached_data)} cached data files found.") + elif metadata_path.endswith(".json"): + with open(metadata_path, "r") as f: + metadata = json.load(f) + self.data = metadata + elif metadata_path.endswith(".jsonl"): + metadata = [] + with open(metadata_path, 'r') as f: + for line in f: + metadata.append(json.loads(line.strip())) + self.data = metadata + else: + metadata = pandas.read_csv(metadata_path) + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + + def __getitem__(self, data_id): + if self.load_from_cache: + data = self.cached_data[data_id % len(self.cached_data)] + data = self.cached_data_operator(data) + else: + data = self.data[data_id % len(self.data)].copy() + for key in self.data_file_keys: + if key in data: + data[key] = self.main_data_operator(data[key]) + return data + + def __len__(self): + if self.max_data_items is not None: + return self.max_data_items + elif self.load_from_cache: + return len(self.cached_data) * self.repeat + else: + return len(self.data) * self.repeat + + def check_data_equal(self, data1, data2): + # Debug only + if len(data1) != len(data2): + return False + for k in data1: + if data1[k] != data2[k]: + return False + return True diff --git a/DiffSynth-Studio/diffsynth/core/device/__init__.py b/DiffSynth-Studio/diffsynth/core/device/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..889d6823adb203630d2173e8ab25fcc350b32ce4 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/device/__init__.py @@ -0,0 +1,2 @@ +from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name +from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE diff --git a/DiffSynth-Studio/diffsynth/core/device/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/device/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9b95079936f5863e7b8fa76473546a56e7dfa3be Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/device/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/device/__pycache__/npu_compatible_device.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/device/__pycache__/npu_compatible_device.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2f7600836a2f7b4a5e47b72cd30a7f7b119ea76 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/device/__pycache__/npu_compatible_device.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/device/npu_compatible_device.py b/DiffSynth-Studio/diffsynth/core/device/npu_compatible_device.py new file mode 100644 index 0000000000000000000000000000000000000000..d96b8fb2479688e2209a7c8349d76ba093aa8e99 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/device/npu_compatible_device.py @@ -0,0 +1,107 @@ +import importlib +import torch +from typing import Any + + +def is_torch_npu_available(): + return importlib.util.find_spec("torch_npu") is not None + + +IS_CUDA_AVAILABLE = torch.cuda.is_available() +IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available() + +if IS_NPU_AVAILABLE: + import torch_npu + + torch.npu.config.allow_internal_format = False + + +def get_device_type() -> str: + """Get device type based on current machine, currently only support CPU, CUDA, NPU.""" + if IS_CUDA_AVAILABLE: + device = "cuda" + elif IS_NPU_AVAILABLE: + device = "npu" + else: + device = "cpu" + + return device + + +def get_torch_device() -> Any: + """Get torch attribute based on device type, e.g. torch.cuda or torch.npu""" + device_name = get_device_type() + + try: + return getattr(torch, device_name) + except AttributeError: + print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.") + return torch.cuda + + +def get_device_id() -> int: + """Get current device id based on device type.""" + return get_torch_device().current_device() + + +def get_device_name() -> str: + """Get current device name based on device type.""" + return f"{get_device_type()}:{get_device_id()}" + + +def synchronize() -> None: + """Execute torch synchronize operation.""" + get_torch_device().synchronize() + + +def empty_cache() -> None: + """Execute torch empty cache operation.""" + get_torch_device().empty_cache() + + +def get_nccl_backend() -> str: + """Return distributed communication backend type based on device type.""" + if IS_CUDA_AVAILABLE: + return "nccl" + elif IS_NPU_AVAILABLE: + return "hccl" + else: + raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.") + + +def enable_high_precision_for_bf16(): + """ + Set high accumulation dtype for matmul and reduction. 
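+ Disabling TF32 and reduced-precision bf16 reductions trades some matmul throughput for numerical accuracy.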
+ """ + if IS_CUDA_AVAILABLE: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + + if IS_NPU_AVAILABLE: + torch.npu.matmul.allow_tf32 = False + torch.npu.matmul.allow_bf16_reduced_precision_reduction = False + + +def parse_device_type(device): + if isinstance(device, str): + if device.startswith("cuda"): + return "cuda" + elif device.startswith("npu"): + return "npu" + else: + return "cpu" + elif isinstance(device, torch.device): + return device.type + + +def parse_nccl_backend(device_type): + if device_type == "cuda": + return "nccl" + elif device_type == "npu": + return "hccl" + else: + raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.") + + +def get_available_device_type(): + return get_device_type() diff --git a/DiffSynth-Studio/diffsynth/core/gradient/__init__.py b/DiffSynth-Studio/diffsynth/core/gradient/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57914792a78ec32f69c3c99ae37535598efc8d52 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/gradient/__init__.py @@ -0,0 +1 @@ +from .gradient_checkpoint import gradient_checkpoint_forward diff --git a/DiffSynth-Studio/diffsynth/core/gradient/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/gradient/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d875e3db482e64509a70a1afd30b923bf574fe8f Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/gradient/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/gradient/__pycache__/gradient_checkpoint.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/gradient/__pycache__/gradient_checkpoint.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb7ff30b74322c0ac1f5194dc8a3567f151e0ab3 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/gradient/__pycache__/gradient_checkpoint.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/gradient/gradient_checkpoint.py b/DiffSynth-Studio/diffsynth/core/gradient/gradient_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..b356415a004f3d74afdd45840f1fc4caf6659e16 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/gradient/gradient_checkpoint.py @@ -0,0 +1,34 @@ +import torch + + +def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + return custom_forward + + +def gradient_checkpoint_forward( + model, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + *args, + **kwargs, +): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + else: + model_output = model(*args, **kwargs) + return model_output diff --git a/DiffSynth-Studio/diffsynth/core/loader/__init__.py b/DiffSynth-Studio/diffsynth/core/loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f56d814bae40436f66bca583e33d180d6e11247 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/loader/__init__.py @@ -0,0 +1,3 @@ +from .file import load_state_dict, hash_state_dict_keys, hash_model_file +from .model import 
load_model, load_model_with_disk_offload +from .config import ModelConfig diff --git a/DiffSynth-Studio/diffsynth/core/loader/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/loader/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b185eb9ca2e8369c3e977561208a68fbf0f15459 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/loader/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/loader/__pycache__/config.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/loader/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6af44e7c41b0d4fd887b3d6f3586e566ca151c0f Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/loader/__pycache__/config.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/loader/__pycache__/file.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/loader/__pycache__/file.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c673427b095f3a38404e05476286e20cc5fcee6 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/loader/__pycache__/file.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/loader/__pycache__/model.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/loader/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63248a67bf3d1e35e0cc3509d639b5e06a28f183 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/loader/__pycache__/model.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/loader/config.py b/DiffSynth-Studio/diffsynth/core/loader/config.py new file mode 100644 index 0000000000000000000000000000000000000000..cbec737c532c8b7ea76c572eb250c4b2f33b655d --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/loader/config.py @@ -0,0 +1,122 @@ +import torch, glob, os +from typing import Optional, Union +from dataclasses import dataclass +# from modelscope import snapshot_download +from huggingface_hub import snapshot_download as hf_snapshot_download +from typing import Optional + + +@dataclass +class ModelConfig: + path: Union[str, list[str]] = None + model_id: str = None + origin_file_pattern: Union[str, list[str]] = None + download_source: str = None + local_model_path: str = None + skip_download: bool = None + offload_device: Optional[Union[str, torch.device]] = None + offload_dtype: Optional[torch.dtype] = None + onload_device: Optional[Union[str, torch.device]] = None + onload_dtype: Optional[torch.dtype] = None + preparing_device: Optional[Union[str, torch.device]] = None + preparing_dtype: Optional[torch.dtype] = None + computation_device: Optional[Union[str, torch.device]] = None + computation_dtype: Optional[torch.dtype] = None + clear_parameters: bool = False + + def check_input(self): + if self.path is None and self.model_id is None: + raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. 
`skip_download=True` only supports the first one.""") + + def parse_original_file_pattern(self): + if self.origin_file_pattern is None or self.origin_file_pattern == "": + return "*" + elif self.origin_file_pattern.endswith("/"): + return self.origin_file_pattern + "*" + else: + return self.origin_file_pattern + + def parse_download_source(self): + if self.download_source is None: + if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None: + return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') + else: + return "modelscope" + else: + return self.download_source + + def parse_skip_download(self): + if self.skip_download is None: + if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None: + if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true": + return True + elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false": + return False + else: + return False + else: + return self.skip_download + + def download(self): + origin_file_pattern = self.parse_original_file_pattern() + downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id)) + download_source = self.parse_download_source() + + + # if download_source.lower() == "modelscope": + # snapshot_download( + # self.model_id, + # local_dir=os.path.join(self.local_model_path, self.model_id), + # allow_file_pattern=origin_file_pattern, + # ignore_file_pattern=downloaded_files, + # local_files_only=False + # ) + # elif + + if download_source.lower() == "huggingface": + hf_snapshot_download( + self.model_id, + local_dir=os.path.join(self.local_model_path, self.model_id), + allow_patterns=origin_file_pattern, + ignore_patterns=downloaded_files, + local_files_only=False + ) + else: + raise ValueError("`download_source` should be `modelscope` or `huggingface`.") + + def require_downloading(self): + if self.path is not None: + return False + skip_download = self.parse_skip_download() + return not skip_download + + def reset_local_model_path(self): + if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None: + self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') + elif self.local_model_path is None: + self.local_model_path = "./models" + + def download_if_necessary(self): + self.check_input() + self.reset_local_model_path() + if self.require_downloading(): + self.download() + if self.path is None: + if self.origin_file_pattern is None or self.origin_file_pattern == "": + self.path = os.path.join(self.local_model_path, self.model_id) + else: + self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)) + if isinstance(self.path, list) and len(self.path) == 1: + self.path = self.path[0] + + def vram_config(self): + return { + "offload_device": self.offload_device, + "offload_dtype": self.offload_dtype, + "onload_device": self.onload_device, + "onload_dtype": self.onload_dtype, + "preparing_device": self.preparing_device, + "preparing_dtype": self.preparing_dtype, + "computation_device": self.computation_device, + "computation_dtype": self.computation_dtype, + } diff --git a/DiffSynth-Studio/diffsynth/core/loader/file.py b/DiffSynth-Studio/diffsynth/core/loader/file.py new file mode 100644 index 0000000000000000000000000000000000000000..8f66961f25d4fc547a2ec638f9d6a93be851afb9 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/loader/file.py @@ -0,0 +1,121 @@ +from safetensors import safe_open +import torch, hashlib + + +def load_state_dict(file_path, torch_dtype=None, device="cpu"): + if isinstance(file_path, list): + state_dict = {} 
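+ # A sharded checkpoint can be passed as a list of files; the shards are merged into a single state dict.
+ # Illustrative call (hypothetical file names):
+ # load_state_dict(["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"], torch_dtype=torch.bfloat16)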
+ for file_path_ in file_path: + state_dict.update(load_state_dict(file_path_, torch_dtype, device)) + return state_dict + if file_path.endswith(".safetensors"): + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device) + else: + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) + + +def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): + state_dict = {} + with safe_open(file_path, framework="pt", device=str(device)) as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + if torch_dtype is not None: + state_dict[k] = state_dict[k].to(torch_dtype) + return state_dict + + +def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"): + state_dict = torch.load(file_path, map_location=device, weights_only=True) + if len(state_dict) == 1: + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + elif "module" in state_dict: + state_dict = state_dict["module"] + elif "model_state" in state_dict: + state_dict = state_dict["model_state"] + if torch_dtype is not None: + for i in state_dict: + if isinstance(state_dict[i], torch.Tensor): + state_dict[i] = state_dict[i].to(torch_dtype) + return state_dict + + +def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): + keys = [] + for key, value in state_dict.items(): + if isinstance(key, str): + if isinstance(value, torch.Tensor): + if with_shape: + shape = "_".join(map(str, list(value.shape))) + keys.append(key + ":" + shape) + keys.append(key) + elif isinstance(value, dict): + keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape)) + keys.sort() + keys_str = ",".join(keys) + return keys_str + + +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() + + +def load_keys_dict(file_path): + if isinstance(file_path, list): + state_dict = {} + for file_path_ in file_path: + state_dict.update(load_keys_dict(file_path_)) + return state_dict + if file_path.endswith(".safetensors"): + return load_keys_dict_from_safetensors(file_path) + else: + return load_keys_dict_from_bin(file_path) + + +def load_keys_dict_from_safetensors(file_path): + keys_dict = {} + with safe_open(file_path, framework="pt", device="cpu") as f: + for k in f.keys(): + keys_dict[k] = f.get_slice(k).get_shape() + return keys_dict + + +def convert_state_dict_to_keys_dict(state_dict): + keys_dict = {} + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + keys_dict[k] = list(v.shape) + else: + keys_dict[k] = convert_state_dict_to_keys_dict(v) + return keys_dict + + +def load_keys_dict_from_bin(file_path): + state_dict = load_state_dict_from_bin(file_path) + keys_dict = convert_state_dict_to_keys_dict(state_dict) + return keys_dict + + +def convert_keys_dict_to_single_str(state_dict, with_shape=True): + keys = [] + for key, value in state_dict.items(): + if isinstance(key, str): + if isinstance(value, dict): + keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape)) + else: + if with_shape: + shape = "_".join(map(str, list(value))) + keys.append(key + ":" + shape) + keys.append(key) + keys.sort() + keys_str = ",".join(keys) + return keys_str + + +def hash_model_file(path, with_shape=True): + keys_dict = load_keys_dict(path) + keys_str = convert_keys_dict_to_single_str(keys_dict, 
with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() diff --git a/DiffSynth-Studio/diffsynth/core/loader/model.py b/DiffSynth-Studio/diffsynth/core/loader/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bf90a23008cee758513d3453d517014ddf5cc953 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/loader/model.py @@ -0,0 +1,83 @@ +from ..vram.initialization import skip_model_initialization +from ..vram.disk_map import DiskMap +from ..vram.layers import enable_vram_management +from .file import load_state_dict +import torch + + +def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None): + config = {} if config is None else config + # Why do we use `skip_model_initialization`? + # It skips the random initialization of model parameters, + # thereby speeding up model loading and avoiding excessive memory usage. + with skip_model_initialization(): + model = model_class(**config) + # What is `module_map`? + # This is a module mapping table for VRAM management. + if module_map is not None: + devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]] + device = [d for d in devices if d != "disk"][0] + dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] + dtype = [d for d in dtypes if d != "disk"][0] + if vram_config["offload_device"] != "disk": + state_dict = DiskMap(path, device, torch_dtype=dtype) + if state_dict_converter is not None: + state_dict = state_dict_converter(state_dict) + else: + state_dict = {i: state_dict[i] for i in state_dict} + model.load_state_dict(state_dict, assign=True) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit) + else: + disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit) + else: + # Why do we use `DiskMap`? + # Sometimes a model file contains multiple models, + # and DiskMap can load only the parameters of a single model, + # avoiding the need to load all parameters in the file. + if use_disk_map: + state_dict = DiskMap(path, device, torch_dtype=torch_dtype) + else: + state_dict = load_state_dict(path, torch_dtype, device) + # Why do we use `state_dict_converter`? + # Some models are saved in complex formats, + # and we need to convert the state dict into the appropriate format. + if state_dict_converter is not None: + state_dict = state_dict_converter(state_dict) + else: + state_dict = {i: state_dict[i] for i in state_dict} + model.load_state_dict(state_dict, assign=True, strict=False) + # Why do we call `to()`? + # Because some models override the behavior of `to()`, + # especially those from libraries like Transformers. 
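+ # Note: the model was built on the meta device by skip_model_initialization, and strict=False allows keys
+ # to be missing from the checkpoint, so any parameter still on meta is materialized with to_empty() first.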
+ if any(p.is_meta for p in model.parameters()): + model = model.to_empty(device=device) + model = model.to(dtype=torch_dtype) + else: + model = model.to(dtype=torch_dtype, device=device) + if hasattr(model, "eval"): + model = model.eval() + return model + + +def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None): + if isinstance(path, str): + path = [path] + config = {} if config is None else config + with skip_model_initialization(): + model = model_class(**config) + if hasattr(model, "eval"): + model = model.eval() + disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter) + vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": device, + "computation_dtype": torch_dtype, + "computation_device": device, + } + enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80) + return model diff --git a/DiffSynth-Studio/diffsynth/core/vram/__init__.py b/DiffSynth-Studio/diffsynth/core/vram/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32763bb9b4abfa4d5b2617827661c520d7e9fcae --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/vram/__init__.py @@ -0,0 +1,2 @@ +from .initialization import skip_model_initialization +from .layers import * diff --git a/DiffSynth-Studio/diffsynth/core/vram/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/vram/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e74b4ca3fabd9f123d8d29e8a51c89298fbb2679 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/vram/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/vram/__pycache__/disk_map.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/vram/__pycache__/disk_map.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e05f84f0179d8dce47cf2d783a66dedb1ee278ee Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/vram/__pycache__/disk_map.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/vram/__pycache__/initialization.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/vram/__pycache__/initialization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06dc0bef5e1ae473d141739adb9e528f96af4fcb Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/vram/__pycache__/initialization.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/vram/__pycache__/layers.cpython-39.pyc b/DiffSynth-Studio/diffsynth/core/vram/__pycache__/layers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8610b95b11079309b5bf7553714a12c9cc56935b Binary files /dev/null and b/DiffSynth-Studio/diffsynth/core/vram/__pycache__/layers.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/core/vram/disk_map.py b/DiffSynth-Studio/diffsynth/core/vram/disk_map.py new file mode 100644 index 0000000000000000000000000000000000000000..a666590fa99a9cc4de05dc3f5fa84c212e43de38 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/vram/disk_map.py @@ -0,0 +1,93 @@ +from safetensors import safe_open +import torch, os + + +class SafetensorsCompatibleTensor: + def __init__(self, tensor): + self.tensor = tensor + + def get_shape(self): + return list(self.tensor.shape) + + +class 
SafetensorsCompatibleBinaryLoader: + def __init__(self, path, device): + print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.") + self.state_dict = torch.load(path, weights_only=True, map_location=device) + + def keys(self): + return self.state_dict.keys() + + def get_tensor(self, name): + return self.state_dict[name] + + def get_slice(self, name): + return SafetensorsCompatibleTensor(self.state_dict[name]) + + +class DiskMap: + + def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9): + self.path = path if isinstance(path, list) else [path] + self.device = device + self.torch_dtype = torch_dtype + if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None: + self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE')) + else: + self.buffer_size = buffer_size + self.files = [] + self.flush_files() + self.name_map = {} + for file_id, file in enumerate(self.files): + for name in file.keys(): + self.name_map[name] = file_id + self.rename_dict = self.fetch_rename_dict(state_dict_converter) + + def flush_files(self): + if len(self.files) == 0: + for path in self.path: + if path.endswith(".safetensors"): + self.files.append(safe_open(path, framework="pt", device=str(self.device))) + else: + self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device)) + else: + for i, path in enumerate(self.path): + if path.endswith(".safetensors"): + self.files[i] = safe_open(path, framework="pt", device=str(self.device)) + self.num_params = 0 + + def __getitem__(self, name): + if self.rename_dict is not None: name = self.rename_dict[name] + file_id = self.name_map[name] + param = self.files[file_id].get_tensor(name) + if self.torch_dtype is not None and isinstance(param, torch.Tensor): + param = param.to(self.torch_dtype) + if isinstance(param, torch.Tensor) and param.device == "cpu": + param = param.clone() + if isinstance(param, torch.Tensor): + self.num_params += param.numel() + if self.num_params > self.buffer_size: + self.flush_files() + return param + + def fetch_rename_dict(self, state_dict_converter): + if state_dict_converter is None: + return None + state_dict = {} + for file in self.files: + for name in file.keys(): + state_dict[name] = name + state_dict = state_dict_converter(state_dict) + return state_dict + + def __iter__(self): + if self.rename_dict is not None: + return self.rename_dict.__iter__() + else: + return self.name_map.__iter__() + + def __contains__(self, x): + if self.rename_dict is not None: + return x in self.rename_dict + else: + return x in self.name_map diff --git a/DiffSynth-Studio/diffsynth/core/vram/initialization.py b/DiffSynth-Studio/diffsynth/core/vram/initialization.py new file mode 100644 index 0000000000000000000000000000000000000000..bff2498b526638bfdd1c114c78aa0b98c251a47d --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/vram/initialization.py @@ -0,0 +1,21 @@ +import torch +from contextlib import contextmanager + + +@contextmanager +def skip_model_initialization(device=torch.device("meta")): + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + old_register_parameter = torch.nn.Module.register_parameter + 
torch.nn.Module.register_parameter = register_empty_parameter + try: + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter diff --git a/DiffSynth-Studio/diffsynth/core/vram/layers.py b/DiffSynth-Studio/diffsynth/core/vram/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..0f99b0d162edfc6e8e7859e28d97de069b44b9ee --- /dev/null +++ b/DiffSynth-Studio/diffsynth/core/vram/layers.py @@ -0,0 +1,479 @@ +import torch, copy +from typing import Union +from .initialization import skip_model_initialization +from .disk_map import DiskMap +from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE + + +class AutoTorchModule(torch.nn.Module): + + def __init__( + self, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + ): + super().__init__() + self.set_dtype_and_device( + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + ) + self.state = 0 + self.name = "" + self.computation_device_type = parse_device_type(self.computation_device) + + def set_dtype_and_device( + self, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + ): + self.offload_dtype = offload_dtype or computation_dtype + self.offload_device = offload_device or computation_dtype + self.onload_dtype = onload_dtype or computation_dtype + self.onload_device = onload_device or computation_dtype + self.preparing_dtype = preparing_dtype or computation_dtype + self.preparing_device = preparing_device or computation_dtype + self.computation_dtype = computation_dtype + self.computation_device = computation_device + self.vram_limit = vram_limit + + def cast_to(self, weight, dtype, device): + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def check_free_vram(self): + device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name() + gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device) + used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3) + return used_memory < self.vram_limit + + def offload(self): + if self.state != 0: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + if self.state != 1: + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def param_name(self, name): + if self.name == "": + return name + else: + return self.name + "." 
+ name + + +class AutoWrappedModule(AutoTorchModule): + + def __init__( + self, + module: torch.nn.Module, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + name: str = "", + disk_map: DiskMap = None, + **kwargs + ): + super().__init__( + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + ) + self.module = module + if offload_dtype == "disk": + self.name = name + self.disk_map = disk_map + self.required_params = [name for name, _ in self.module.named_parameters()] + self.disk_offload = True + else: + self.disk_offload = False + + def load_from_disk(self, torch_dtype, device, copy_module=False): + if copy_module: + module = copy.deepcopy(self.module) + else: + module = self.module + state_dict = {} + for name in self.required_params: + param = self.disk_map[self.param_name(name)] + param = param.to(dtype=torch_dtype, device=device) + state_dict[name] = param + module.load_state_dict(state_dict, assign=True) + module.to(dtype=torch_dtype, device=device) + return module + + def offload_to_disk(self, model: torch.nn.Module): + for buf in model.buffers(): + # If there are some parameters are registed in buffers (not in state dict), + # We cannot offload the model. + for children in model.children(): + self.offload_to_disk(children) + break + else: + model.to("meta") + + def offload(self): + # offload / onload / preparing -> offload + if self.state != 0: + if self.disk_offload: + self.offload_to_disk(self.module) + else: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + # offload / onload / preparing -> onload + if self.state < 1: + if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk": + self.load_from_disk(self.onload_dtype, self.onload_device) + elif self.onload_device != "disk": + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def preparing(self): + # onload / preparing -> preparing + if self.state != 2: + if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk": + self.load_from_disk(self.preparing_dtype, self.preparing_device) + elif self.preparing_device != "disk": + self.to(dtype=self.preparing_dtype, device=self.preparing_device) + self.state = 2 + + def cast_to(self, module, dtype, device): + return copy.deepcopy(module).to(dtype=dtype, device=device) + + def computation(self): + # onload / preparing -> computation (temporary) + if self.state == 2: + torch_dtype, device = self.preparing_dtype, self.preparing_device + else: + torch_dtype, device = self.onload_dtype, self.onload_device + if torch_dtype == self.computation_dtype and device == self.computation_device: + module = self.module + elif self.disk_offload and device == "disk": + module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True) + else: + module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device) + return module + + def forward(self, *args, **kwargs): + if self.state == 1 and (self.vram_limit is None or self.check_free_vram()): + self.preparing() + 
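+ # States as used in this file: 0 = offloaded, 1 = onloaded, 2 = prepared.
+ # computation() returns the wrapped module itself when it already matches the computation dtype/device;
+ # otherwise it returns a temporary cast copy that is used for this forward pass only.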
module = self.computation() + return module(*args, **kwargs) + + def __getattr__(self, name): + if name in self.__dict__ or name == "module": + return super().__getattr__(name) + else: + return getattr(self.module, name) + + +class AutoWrappedNonRecurseModule(AutoWrappedModule): + + def __init__( + self, + module: torch.nn.Module, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + name: str = "", + disk_map: DiskMap = None, + **kwargs + ): + super().__init__( + module, + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + name, + disk_map, + **kwargs + ) + if self.disk_offload: + self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)] + + def load_from_disk(self, torch_dtype, device, copy_module=False): + if copy_module: + module = copy.deepcopy(self.module) + else: + module = self.module + state_dict = {} + for name in self.required_params: + param = self.disk_map[self.param_name(name)] + param = param.to(dtype=torch_dtype, device=device) + state_dict[name] = param + module.load_state_dict(state_dict, assign=True, strict=False) + return module + + def offload_to_disk(self, model: torch.nn.Module): + for name in self.required_params: + getattr(self, name).to("meta") + + def cast_to(self, module, dtype, device): + # Parameter casting is implemented in the model architecture. 
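+ # Returning the module unchanged is intentional: a non-recurse wrapper only owns its direct
+ # (recurse=False) parameters, and its child modules are wrapped and cast by their own wrappers.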
+ return module + + def __getattr__(self, name): + if name in self.__dict__ or name == "module": + return super().__getattr__(name) + else: + return getattr(self.module, name) + + +class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): + def __init__( + self, + module: torch.nn.Linear, + offload_dtype: torch.dtype = None, + offload_device: Union[str, torch.device] = None, + onload_dtype: torch.dtype = None, + onload_device: Union[str, torch.device] = None, + preparing_dtype: torch.dtype = None, + preparing_device: Union[str, torch.device] = None, + computation_dtype: torch.dtype = None, + computation_device: Union[str, torch.device] = None, + vram_limit: float = None, + name: str = "", + disk_map: DiskMap = None, + **kwargs + ): + with skip_model_initialization(): + super().__init__( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + ) + self.set_dtype_and_device( + offload_dtype, + offload_device, + onload_dtype, + onload_device, + preparing_dtype, + preparing_device, + computation_dtype, + computation_device, + vram_limit, + ) + self.weight = module.weight + self.bias = module.bias + self.state = 0 + self.name = name + self.lora_A_weights = [] + self.lora_B_weights = [] + self.lora_merger = None + self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz] + self.computation_device_type = parse_device_type(self.computation_device) + + if offload_dtype == "disk": + self.disk_map = disk_map + self.disk_offload = True + else: + self.disk_offload = False + + def fp8_linear( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor = None, + ) -> torch.Tensor: + device = input.device + origin_dtype = input.dtype + origin_shape = input.shape + input = input.reshape(-1, origin_shape[-1]) + + x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values + fp8_max = 448.0 + # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn. + # To avoid overflow and ensure numerical compatibility during FP8 computation, + # we scale down the input by 2.0 in advance. + # This scaling will be compensated later during the final result scaling. 
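+ # Worked example of the per-row scaling below: with fp8_max = 448 and a row whose max |x| is 896,
+ # scale_a = clamp(896 / 448, min=1.0) = 2.0, the row is divided by 2 before the FP8 cast,
+ # and torch._scaled_mm folds scale_a back into the output.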
+ if self.computation_dtype == torch.float8_e4m3fnuz: + fp8_max = fp8_max / 2.0 + scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device) + scale_b = torch.ones((weight.shape[0], 1)).to(device=device) + input = input / (scale_a + 1e-8) + input = input.to(self.computation_dtype) + weight = weight.to(self.computation_dtype) + bias = bias.to(torch.bfloat16) + + result = torch._scaled_mm( + input, + weight.T, + scale_a=scale_a, + scale_b=scale_b.T, + bias=bias, + out_dtype=origin_dtype, + ) + new_shape = origin_shape[:-1] + result.shape[-1:] + result = result.reshape(new_shape) + return result + + def load_from_disk(self, torch_dtype, device, assign=True): + weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device) + bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device) + if assign: + state_dict = {"weight": weight} + if bias is not None: state_dict["bias"] = bias + self.load_state_dict(state_dict, assign=True) + return weight, bias + + def offload(self): + # offload / onload / preparing -> offload + if self.state != 0: + if self.disk_offload: + self.to("meta") + else: + self.to(dtype=self.offload_dtype, device=self.offload_device) + self.state = 0 + + def onload(self): + # offload / onload / preparing -> onload + if self.state < 1: + if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk": + self.load_from_disk(self.onload_dtype, self.onload_device) + elif self.onload_device != "disk": + self.to(dtype=self.onload_dtype, device=self.onload_device) + self.state = 1 + + def preparing(self): + # onload / preparing -> preparing + if self.state != 2: + if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk": + self.load_from_disk(self.preparing_dtype, self.preparing_device) + elif self.preparing_device != "disk": + self.to(dtype=self.preparing_dtype, device=self.preparing_device) + self.state = 2 + + def computation(self): + # onload / preparing -> computation (temporary) + if self.state == 2: + torch_dtype, device = self.preparing_dtype, self.preparing_device + else: + torch_dtype, device = self.onload_dtype, self.onload_device + if torch_dtype == self.computation_dtype and device == self.computation_device: + weight, bias = self.weight, self.bias + elif self.disk_offload and device == "disk": + weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False) + else: + weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device) + bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device) + return weight, bias + + def linear_forward(self, x, weight, bias): + if self.enable_fp8: + out = self.fp8_linear(x, weight, bias) + else: + out = torch.nn.functional.linear(x, weight, bias) + return out + + def lora_forward(self, x, out): + if self.lora_merger is None: + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + out = out + x @ lora_A.T @ lora_B.T + else: + lora_output = [] + for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights): + lora_output.append(x @ lora_A.T @ lora_B.T) + lora_output = torch.stack(lora_output) + out = self.lora_merger(out, lora_output) + return out + + def forward(self, x, *args, **kwargs): + if self.state == 1 and (self.vram_limit is None or self.check_free_vram()): + self.preparing() + weight, bias = self.computation() + out = self.linear_forward(x, weight, bias) + if 
len(self.lora_A_weights) > 0:
+ out = self.lora_forward(x, out)
+ return out
+
+
+def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
+ if isinstance(model, AutoWrappedNonRecurseModule):
+ model = model.module
+ for name, module in model.named_children():
+ layer_name = name if name_prefix == "" else name_prefix + "." + name
+ for source_module, target_module in module_map.items():
+ if isinstance(module, source_module):
+ module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
+ if isinstance(module_, AutoWrappedNonRecurseModule):
+ enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
+ setattr(model, name, module_)
+ break
+ else:
+ enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
+
+
+def fill_vram_config(model, vram_config):
+ vram_config_ = vram_config.copy()
+ vram_config_["onload_dtype"] = vram_config["computation_dtype"]
+ vram_config_["onload_device"] = vram_config["computation_device"]
+ vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
+ vram_config_["preparing_device"] = vram_config["computation_device"]
+ for k in vram_config:
+ if vram_config[k] != vram_config_[k]:
+ print(f"No fine-grained VRAM configuration was provided for {model.__class__.__name__}. The `onload`, `preparing`, and `computation` states will share the same settings. `vram_config` is set to {vram_config_}")
+ break
+ return vram_config_
+
+
+def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
+ for source_module, target_module in module_map.items():
+ # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
+ if isinstance(model, source_module):
+ vram_config = fill_vram_config(model, vram_config)
+ model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
+ break
+ else:
+ enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
+ # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
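+ # Hypothetical call site (module map and config values below are illustrative, not part of this file):
+ #   module_map = {torch.nn.Linear: AutoWrappedLinear}
+ #   vram_config = dict(
+ #       offload_dtype=torch.bfloat16, offload_device="cpu",
+ #       onload_dtype=torch.bfloat16, onload_device="cpu",
+ #       preparing_dtype=torch.bfloat16, preparing_device="cuda",
+ #       computation_dtype=torch.bfloat16, computation_device="cuda",
+ #   )
+ #   dit = enable_vram_management(dit, module_map, vram_config, vram_limit=6.0)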
+ model.vram_management_enabled = True + return model diff --git a/DiffSynth-Studio/diffsynth/diffusion/__init__.py b/DiffSynth-Studio/diffsynth/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4a0873a7b3d09e95aa00cfe340d653c58a834b --- /dev/null +++ b/DiffSynth-Studio/diffsynth/diffusion/__init__.py @@ -0,0 +1,6 @@ +from .flow_match import FlowMatchScheduler +from .training_module import DiffusionTrainingModule +from .logger import ModelLogger +from .runner import launch_training_task, launch_data_process_task +from .parsers import * +from .loss import * diff --git a/DiffSynth-Studio/diffsynth/diffusion/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fd428da6018133948781acb3ac393a882fcd4a2 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/diffusion/__pycache__/base_pipeline.cpython-39.pyc b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/base_pipeline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3087a527565e1426764a3b72fb952ff1121b90c3 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/base_pipeline.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/diffusion/__pycache__/flow_match.cpython-39.pyc b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/flow_match.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3b3d34fdcad13aa0f3365a6cf3b92a33e4f367c Binary files /dev/null and b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/flow_match.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/diffusion/__pycache__/logger.cpython-39.pyc b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99709fa9f32a159c423e4213f5fa6610e7940986 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/logger.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/diffusion/__pycache__/loss.cpython-39.pyc b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..807ba6cf79ff322fbbcad9786ba731bee9004c95 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/loss.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/diffusion/__pycache__/parsers.cpython-39.pyc b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/parsers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3587e08bd97b72182ab5f20f77e04bac66c2506 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/parsers.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/diffusion/__pycache__/runner.cpython-39.pyc b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/runner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..551b81ccb12799a18fa05694e97b0ac12f9c5f99 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/runner.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/diffusion/__pycache__/training_module.cpython-39.pyc b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/training_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d81685a7a827b9eae15adb63f06c406542fa705 Binary 
files /dev/null and b/DiffSynth-Studio/diffsynth/diffusion/__pycache__/training_module.cpython-39.pyc differ
diff --git a/DiffSynth-Studio/diffsynth/diffusion/base_pipeline.py b/DiffSynth-Studio/diffsynth/diffusion/base_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4731fd18d62ed171fe5d993faedd7656ff22bab
--- /dev/null
+++ b/DiffSynth-Studio/diffsynth/diffusion/base_pipeline.py
@@ -0,0 +1,451 @@
+from PIL import Image
+import torch
+import numpy as np
+from einops import repeat, reduce
+from typing import Union
+from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core.device.npu_compatible_device import get_device_type
+from ..utils.lora import GeneralLoRALoader
+from ..models.model_loader import ModelPool
+from ..utils.controlnet import ControlNetInput
+from ..core.device import get_device_name, IS_NPU_AVAILABLE
+
+
+class PipelineUnit:
+ def __init__(
+ self,
+ seperate_cfg: bool = False,
+ take_over: bool = False,
+ input_params: tuple[str] = None,
+ output_params: tuple[str] = None,
+ input_params_posi: dict[str, str] = None,
+ input_params_nega: dict[str, str] = None,
+ onload_model_names: tuple[str] = None
+ ):
+ self.seperate_cfg = seperate_cfg
+ self.take_over = take_over
+ self.input_params = input_params
+ self.output_params = output_params
+ self.input_params_posi = input_params_posi
+ self.input_params_nega = input_params_nega
+ self.onload_model_names = onload_model_names
+
+ def fetch_input_params(self):
+ params = []
+ if self.input_params is not None:
+ for param in self.input_params:
+ params.append(param)
+ if self.input_params_posi is not None:
+ for _, param in self.input_params_posi.items():
+ params.append(param)
+ if self.input_params_nega is not None:
+ for _, param in self.input_params_nega.items():
+ params.append(param)
+ params = sorted(list(set(params)))
+ return params
+
+ def fetch_output_params(self):
+ params = []
+ if self.output_params is not None:
+ for param in self.output_params:
+ params.append(param)
+ return params
+
+ def process(self, pipe, **kwargs) -> dict:
+ return {}
+
+ def post_process(self, pipe, **kwargs) -> dict:
+ return {}
+
+
+class BasePipeline(torch.nn.Module):
+
+ def __init__(
+ self,
+ device=get_device_type(), torch_dtype=torch.float16,
+ height_division_factor=64, width_division_factor=64,
+ time_division_factor=None, time_division_remainder=None,
+ ):
+ super().__init__()
+ # The device and torch_dtype are used for storing intermediate variables, not models.
+ self.device = device
+ self.torch_dtype = torch_dtype
+ self.device_type = parse_device_type(device)
+ # The following parameters are used for shape checks.
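+ # Example: with height_division_factor=64, a requested height of 1000 is rounded up
+ # to 1024 by check_resize_height_width below.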
+ self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # VRAM management + self.vram_management_enabled = False + # Pipeline Unit Runner + self.unit_runner = PipelineUnitRunner() + # LoRA Loader + self.lora_loader = GeneralLoRALoader + + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self.device = device + if dtype is not None: + self.torch_dtype = dtype + super().to(*args, **kwargs) + return self + + + def check_resize_height_width(self, height, width, num_frames=None): + # Shape check + if height % self.height_division_factor != 0: + height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor + print(f"height % {self.height_division_factor} != 0. We round it up to {height}.") + if width % self.width_division_factor != 0: + width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor + print(f"width % {self.width_division_factor} != 0. We round it up to {width}.") + if num_frames is None: + return height, width + else: + if num_frames % self.time_division_factor != self.time_division_remainder: + num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder + print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.") + return height, width, num_frames + + + def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1): + # Transform a PIL.Image to torch.Tensor + image = torch.Tensor(np.array(image, dtype=np.float32)) + image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + image = image * ((max_value - min_value) / 255) + min_value + image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})) + return image + + + def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a list of PIL.Image to torch.Tensor + video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video] + video = torch.stack(video, dim=pattern.index("T") // 2) + return video + + + def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to PIL.Image + if pattern != "H W C": + vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean") + image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255) + image = image.to(device="cpu", dtype=torch.uint8) + image = Image.fromarray(image.numpy()) + return image + + + def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to list of PIL.Image + if pattern != "T H W C": + vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean") + video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output] + return video + + + def load_models_to_device(self, model_names): + if self.vram_management_enabled: + # offload models + for name, model in self.named_children(): + if name not in model_names: + if 
hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + if hasattr(model, "offload"): + model.offload() + else: + for module in model.modules(): + if hasattr(module, "offload"): + module.offload() + getattr(torch, self.device_type).empty_cache() + # onload models + for name, model in self.named_children(): + if name in model_names: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + if hasattr(model, "onload"): + model.onload() + else: + for module in model.modules(): + if hasattr(module, "onload"): + module.onload() + + + def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None): + # Initialize Gaussian noise + generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype) + noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + return noise + + + def get_vram(self): + device = self.device if not IS_NPU_AVAILABLE else get_device_name() + return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3) + + def get_module(self, model, name): + if "." in name: + name, suffix = name[:name.index(".")], name[name.index(".") + 1:] + if name.isdigit(): + return self.get_module(model[int(name)], suffix) + else: + return self.get_module(getattr(model, name), suffix) + else: + return getattr(model, name) + + def freeze_except(self, model_names): + self.eval() + self.requires_grad_(False) + for name in model_names: + module = self.get_module(self, name) + if module is None: + print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.") + continue + module.train() + module.requires_grad_(True) + + + def blend_with_mask(self, base, addition, mask): + return base * (1 - mask) + addition * mask + + + def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs): + timestep = scheduler.timesteps[progress_id] + if inpaint_mask is not None: + noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents) + noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask) + latents_next = scheduler.step(noise_pred, timestep, latents) + return latents_next + + + def split_pipeline_units(self, model_names: list[str]): + return PipelineUnitGraph().split_pipeline_units(self.units, model_names) + + + def flush_vram_management_device(self, device): + for module in self.modules(): + if isinstance(module, AutoTorchModule): + module.offload_device = device + module.onload_device = device + module.preparing_device = device + module.computation_device = device + + + def load_lora( + self, + module: torch.nn.Module, + lora_config: Union[ModelConfig, str] = None, + alpha=1, + hotload=None, + state_dict=None, + verbose=1, + ): + if state_dict is None: + if isinstance(lora_config, str): + lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device) + else: + lora_config.download_if_necessary() + lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device) + else: + lora = state_dict + lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) + lora = lora_loader.convert_state_dict(lora) + if hotload is None: + hotload = hasattr(module, "vram_management_enabled") and 
getattr(module, "vram_management_enabled") + if hotload: + if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")): + raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.") + updated_num = 0 + for _, module in module.named_modules(): + if isinstance(module, AutoWrappedLinear): + name = module.name + lora_a_name = f'{name}.lora_A.weight' + lora_b_name = f'{name}.lora_B.weight' + if lora_a_name in lora and lora_b_name in lora: + updated_num += 1 + module.lora_A_weights.append(lora[lora_a_name] * alpha) + module.lora_B_weights.append(lora[lora_b_name]) + if verbose >= 1: + print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.") + else: + lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha) + + + def clear_lora(self, verbose=1): + cleared_num = 0 + for name, module in self.named_modules(): + if isinstance(module, AutoWrappedLinear): + if hasattr(module, "lora_A_weights"): + if len(module.lora_A_weights) > 0: + cleared_num += 1 + module.lora_A_weights.clear() + if hasattr(module, "lora_B_weights"): + module.lora_B_weights.clear() + if verbose >= 1: + print(f"{cleared_num} LoRA layers are cleared.") + + + def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None): + model_pool = ModelPool() + for model_config in model_configs: + model_config.download_if_necessary() + vram_config = model_config.vram_config() + vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype + vram_config["computation_device"] = vram_config["computation_device"] or self.device + model_pool.auto_load_model( + model_config.path, + vram_config=vram_config, + vram_limit=vram_limit, + clear_parameters=model_config.clear_parameters, + ) + return model_pool + + + def check_vram_management_state(self): + vram_management_enabled = False + for module in self.children(): + if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"): + vram_management_enabled = True + return vram_management_enabled + + + def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) + noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) + if cfg_scale != 1.0: + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + return noise_pred + + +class PipelineUnitGraph: + def __init__(self): + pass + + def build_edges(self, units: list[PipelineUnit]): + # Establish dependencies between units + # to search for subsequent related computation units. + last_compute_unit_id = {} + edges = [] + for unit_id, unit in enumerate(units): + for input_param in unit.fetch_input_params(): + if input_param in last_compute_unit_id: + edges.append((last_compute_unit_id[input_param], unit_id)) + for output_param in unit.fetch_output_params(): + last_compute_unit_id[output_param] = unit_id + return edges + + def build_chains(self, units: list[PipelineUnit]): + # Establish updating chains for each variable + # to track their computation process. 
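+ # Example: if unit 0 and unit 3 both declare "latents" as an output, chains["latents"]
+ # becomes [0, 3], recording that the variable is first produced by unit 0 and later
+ # overwritten by unit 3.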
+ params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], []) + params = sorted(list(set(params))) + chains = {param: [] for param in params} + for unit_id, unit in enumerate(units): + for param in unit.fetch_output_params(): + chains[param].append(unit_id) + return chains + + def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]): + # Search for units that directly participate in the model's computation. + related_unit_ids = [] + for unit_id, unit in enumerate(units): + for model_name in model_names: + if unit.onload_model_names is not None and model_name in unit.onload_model_names: + related_unit_ids.append(unit_id) + break + return related_unit_ids + + def search_related_unit_ids(self, edges, start_unit_ids, direction="target"): + # Search for subsequent related computation units. + related_unit_ids = [unit_id for unit_id in start_unit_ids] + while True: + neighbors = [] + for source, target in edges: + if direction == "target" and source in related_unit_ids and target not in related_unit_ids: + neighbors.append(target) + elif direction == "source" and source not in related_unit_ids and target in related_unit_ids: + neighbors.append(source) + neighbors = sorted(list(set(neighbors))) + if len(neighbors) == 0: + break + else: + related_unit_ids.extend(neighbors) + related_unit_ids = sorted(list(set(related_unit_ids))) + return related_unit_ids + + def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids): + # If the input parameters of this subgraph are updated outside the subgraph, + # search for the units where these updates occur. + first_compute_unit_id = {} + for unit_id in related_unit_ids: + for param in units[unit_id].fetch_input_params(): + if param not in first_compute_unit_id: + first_compute_unit_id[param] = unit_id + updating_unit_ids = [] + for param in first_compute_unit_id: + unit_id = first_compute_unit_id[param] + chain = chains[param] + if unit_id in chain and chain.index(unit_id) != len(chain) - 1: + for unit_id_ in chain[chain.index(unit_id) + 1:]: + if unit_id_ not in related_unit_ids: + updating_unit_ids.append(unit_id_) + related_unit_ids.extend(updating_unit_ids) + related_unit_ids = sorted(list(set(related_unit_ids))) + return related_unit_ids + + def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]): + # Split the computation graph, + # separating all model-related computations. + related_unit_ids = self.search_direct_unit_ids(units, model_names) + edges = self.build_edges(units) + chains = self.build_chains(units) + while True: + num_related_unit_ids = len(related_unit_ids) + related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target") + related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids) + if len(related_unit_ids) == num_related_unit_ids: + break + else: + num_related_unit_ids = len(related_unit_ids) + related_units = [units[i] for i in related_unit_ids] + unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids] + return related_units, unrelated_units + + +class PipelineUnitRunner: + def __init__(self): + pass + + def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: + if unit.take_over: + # Let the pipeline unit take over this function. 
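+ # A take_over unit receives the full shared/positive/negative input dictionaries and
+ # must return all three, instead of reading and writing individual declared parameters.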
+ inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega) + elif unit.seperate_cfg: + # Positive side + processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()} + if unit.input_params is not None: + for name in unit.input_params: + processor_inputs[name] = inputs_shared.get(name) + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_posi.update(processor_outputs) + # Negative side + if inputs_shared["cfg_scale"] != 1: + processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()} + if unit.input_params is not None: + for name in unit.input_params: + processor_inputs[name] = inputs_shared.get(name) + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_nega.update(processor_outputs) + else: + inputs_nega.update(processor_outputs) + else: + processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params} + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_shared.update(processor_outputs) + return inputs_shared, inputs_posi, inputs_nega diff --git a/DiffSynth-Studio/diffsynth/diffusion/flow_match.py b/DiffSynth-Studio/diffsynth/diffusion/flow_match.py new file mode 100644 index 0000000000000000000000000000000000000000..2d6b3676cab40ea13d10828f2e14b1d297c9b4ec --- /dev/null +++ b/DiffSynth-Studio/diffsynth/diffusion/flow_match.py @@ -0,0 +1,184 @@ +import torch, math +from typing_extensions import Literal + + +class FlowMatchScheduler(): + + def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"): + self.set_timesteps_fn = { + "FLUX.1": FlowMatchScheduler.set_timesteps_flux, + "Wan": FlowMatchScheduler.set_timesteps_wan, + "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image, + "FLUX.2": FlowMatchScheduler.set_timesteps_flux2, + "Z-Image": FlowMatchScheduler.set_timesteps_z_image, + }.get(template, FlowMatchScheduler.set_timesteps_flux) + self.num_train_timesteps = 1000 + + @staticmethod + def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None): + sigma_min = 0.003/1.002 + sigma_max = 1.0 + shift = 3 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None): + sigma_min = 0.0 + sigma_max = 1.0 + shift = 5 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + @staticmethod + def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None): + sigma_min = 0.0 + sigma_max = 1.0 + num_train_timesteps = 1000 + 
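+ # Illustrative value for the exponential shift applied below: with mu = 0.8, a uniform
+ # sigma of 0.5 is mapped to exp(0.8) / (exp(0.8) + 1) ≈ 0.69, so more of the schedule
+ # is spent at high noise levels.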
shift_terminal = 0.02 + # Sigmas + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + # Mu + if exponential_shift_mu is not None: + mu = exponential_shift_mu + elif dynamic_shift_len is not None: + mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len) + else: + mu = 0.8 + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + # Shift terminal + one_minus_z = 1 - sigmas + scale_factor = one_minus_z[-1] / (1 - shift_terminal) + sigmas = 1 - (one_minus_z / scale_factor) + # Timesteps + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def compute_empirical_mu(image_seq_len, num_steps): + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + @staticmethod + def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None): + sigma_min = 1 / num_inference_steps + sigma_max = 1.0 + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps) + if dynamic_shift_len is None: + # If you ask me why I set mu=0.8, + # I can only say that it yields better training results. + mu = 0.8 + else: + mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps) + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None): + sigma_min = 0.0 + sigma_max = 1.0 + shift = 3 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + if target_timesteps is not None: + target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device) + for timestep in target_timesteps: + timestep_id = torch.argmin((timesteps - timestep).abs()) + timesteps[timestep_id] = timestep + return sigmas, timesteps + + def set_training_weight(self): + steps = 1000 + x = self.timesteps + y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2) + y_shifted = y - y.min() + bsmntw_weighing = y_shifted * (steps / y_shifted.sum()) + if len(self.timesteps) != 1000: + # This is an empirical formula. 
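+ # After the rescaling below, the weights average roughly 1 over the sampled timesteps,
+ # keeping the loss scale comparable when fewer than 1000 timesteps are used.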
+ bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps) + bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1] + self.linear_timesteps_weights = bsmntw_weighing + + def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs): + self.sigmas, self.timesteps = self.set_timesteps_fn( + num_inference_steps=num_inference_steps, + denoising_strength=denoising_strength, + **kwargs, + ) + if training: + self.set_training_weight() + self.training = True + else: + self.training = False + + def step(self, model_output, timestep, sample, to_final=False, **kwargs): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + if to_final or timestep_id + 1 >= len(self.timesteps): + sigma_ = 0 + else: + sigma_ = self.sigmas[timestep_id + 1] + prev_sample = sample + model_output * (sigma_ - sigma) + return prev_sample + + def return_to_timestep(self, timestep, sample, sample_stablized): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + model_output = (sample - sample_stablized) / sigma + return model_output + + def add_noise(self, original_samples, noise, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + sample = (1 - sigma) * original_samples + sigma * noise + return sample + + def training_target(self, sample, noise, timestep): + target = noise - sample + return target + + def training_weight(self, timestep): + timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs()) + weights = self.linear_timesteps_weights[timestep_id] + return weights diff --git a/DiffSynth-Studio/diffsynth/diffusion/logger.py b/DiffSynth-Studio/diffsynth/diffusion/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2792f1bb2cf61fe6756aa8316bcb16950d357a --- /dev/null +++ b/DiffSynth-Studio/diffsynth/diffusion/logger.py @@ -0,0 +1,43 @@ +import os, torch +from accelerate import Accelerator + + +class ModelLogger: + def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x): + self.output_path = output_path + self.remove_prefix_in_ckpt = remove_prefix_in_ckpt + self.state_dict_converter = state_dict_converter + self.num_steps = 0 + + + def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs): + self.num_steps += 1 + if save_steps is not None and self.num_steps % save_steps == 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + accelerator.save(state_dict, path, safe_serialization=True) + + + def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None): + if save_steps is not None and self.num_steps % save_steps != 0: + 
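+ # Only write a final checkpoint if the last training step did not already trigger a save.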
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, file_name) + accelerator.save(state_dict, path, safe_serialization=True) diff --git a/DiffSynth-Studio/diffsynth/diffusion/loss.py b/DiffSynth-Studio/diffsynth/diffusion/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a07e6878bb72d2370027844ef90ab50ad0e710ae --- /dev/null +++ b/DiffSynth-Studio/diffsynth/diffusion/loss.py @@ -0,0 +1,145 @@ +from .base_pipeline import BasePipeline +import torch +import copy + +def FlowMatchSFTLoss(pipe: BasePipeline, **inputs): + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps)) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps)) + + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) + timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device) + + + noise = torch.randn_like(inputs["input_latents"]) + + origin_latents = copy.deepcopy(inputs["input_latents"]) + noisy_latents = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep) + tgt_latent_len = noisy_latents.shape[2] // 2 + noisy_latents[:, :, tgt_latent_len:, ...] = origin_latents[:, :, tgt_latent_len:, ...] + inputs["latents"] = noisy_latents + + if "first_frame_latents" in inputs: + inputs["latents"][:, :, 0:1] = inputs['first_frame_latents'] + + + training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep) + + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep) + + + diff = (noise_pred[:, :, 1:tgt_latent_len] - training_target[:, :, 1:tgt_latent_len])**2 + # diff: [B,C,T,H,W] + + gamma = 0.01 + T = tgt_latent_len + i = torch.arange(1, T, device=diff.device).float() + + d = torch.abs(2 * i / (T - 1) - 1.0) + w_f = 1.0 + gamma * d**2 # [T] + + w_f = w_f.view(1,1,T-1,1,1) + + loss = (diff * w_f).mean() + loss = loss * pipe.scheduler.training_weight(timestep) + + + return loss + + +def DirectDistillLoss(pipe: BasePipeline, **inputs): + pipe.scheduler.set_timesteps(inputs["num_inference_steps"]) + pipe.scheduler.training = True + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id) + inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs) + loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float()) + return loss + + +class TrajectoryImitationLoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.initialized = False + + def initialize(self, device): + import lpips # TODO: remove it + self.loss_fn = lpips.LPIPS(net='alex').to(device) + self.initialized = True + + def 
fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale): + trajectory = [inputs_shared["latents"].clone()] + + pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student) + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + noise_pred = pipe.cfg_guided_model_fn( + pipe.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared) + + trajectory.append(inputs_shared["latents"].clone()) + return pipe.scheduler.timesteps, trajectory + + def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale): + loss = 0 + pipe.scheduler.set_timesteps(num_inference_steps, training=True) + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + + progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs()) + inputs_shared["latents"] = trajectory_teacher[progress_id_teacher] + + noise_pred = pipe.cfg_guided_model_fn( + pipe.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + + sigma = pipe.scheduler.sigmas[progress_id] + sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1] + if progress_id + 1 >= len(pipe.scheduler.timesteps): + latents_ = trajectory_teacher[-1] + else: + progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs()) + latents_ = trajectory_teacher[progress_id_teacher] + + target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma) + loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep) + return loss + + def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale): + inputs_shared["latents"] = trajectory_teacher[0] + pipe.scheduler.set_timesteps(num_inference_steps) + models = {name: getattr(pipe, name) for name in pipe.in_iteration_models} + for progress_id, timestep in enumerate(pipe.scheduler.timesteps): + timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device) + noise_pred = pipe.cfg_guided_model_fn( + pipe.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared) + + image_pred = pipe.vae_decoder(inputs_shared["latents"]) + image_real = pipe.vae_decoder(trajectory_teacher[-1]) + loss = self.loss_fn(image_pred.float(), image_real.float()) + return loss + + def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega): + if not self.initialized: + self.initialize(pipe.device) + with torch.no_grad(): + pipe.scheduler.set_timesteps(8) + timesteps_teacher, trajectory_teacher = 
self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2) + timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device) + loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1) + loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1) + loss = loss_1 + loss_2 + return loss diff --git a/DiffSynth-Studio/diffsynth/diffusion/parsers.py b/DiffSynth-Studio/diffsynth/diffusion/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..68460b58a6f4ec5c2a8f7698ea039826418d2405 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/diffusion/parsers.py @@ -0,0 +1,70 @@ +import argparse + + +def add_dataset_base_config(parser: argparse.ArgumentParser): + parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.") + parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.") + parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") + parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.") + return parser + +def add_image_size_config(parser: argparse.ArgumentParser): + parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.") + return parser + +def add_video_size_config(parser: argparse.ArgumentParser): + parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.") + parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.") + return parser + +def add_model_config(parser: argparse.ArgumentParser): + parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.") + parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.") + parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") + parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.") + parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. 
Only used in splited training.") + return parser + +def add_training_config(parser: argparse.ArgumentParser): + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") + parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") + parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.") + return parser + +def add_output_config(parser: argparse.ArgumentParser): + parser.add_argument("--output_path", type=str, default="./models", help="Output save path.") + parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.") + parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.") + return parser + +def add_lora_config(parser: argparse.ArgumentParser): + parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.") + parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") + parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") + parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.") + parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. 
If provided, this LoRA will be fused to the base model.") + parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.") + return parser + +def add_gradient_config(parser: argparse.ArgumentParser): + parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") + parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + return parser + +def add_general_config(parser: argparse.ArgumentParser): + parser = add_dataset_base_config(parser) + parser = add_model_config(parser) + parser = add_training_config(parser) + parser = add_output_config(parser) + parser = add_lora_config(parser) + parser = add_gradient_config(parser) + return parser diff --git a/DiffSynth-Studio/diffsynth/diffusion/runner.py b/DiffSynth-Studio/diffsynth/diffusion/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..321557d240302d652b6c477f2444a5be427c4d78 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/diffusion/runner.py @@ -0,0 +1,74 @@ +import os, torch +from tqdm import tqdm +from accelerate import Accelerator +from .training_module import DiffusionTrainingModule +from .logger import ModelLogger + + +def launch_training_task( + accelerator: Accelerator, + dataset: torch.utils.data.Dataset, + model: DiffusionTrainingModule, + model_logger: ModelLogger, + learning_rate: float = 1e-5, + weight_decay: float = 1e-2, + num_workers: int = 1, + save_steps: int = None, + num_epochs: int = 1, + args = None, +): + if args is not None: + learning_rate = args.learning_rate + weight_decay = args.weight_decay + num_workers = args.dataset_num_workers + save_steps = args.save_steps + num_epochs = args.num_epochs + + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + + model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) + + for epoch_id in range(num_epochs): + progress_bar = tqdm(dataloader, disable=not accelerator.is_main_process) + for data in progress_bar: + with accelerator.accumulate(model): + optimizer.zero_grad() + + loss = model(data) + accelerator.backward(loss) + optimizer.step() + + if accelerator.is_main_process: + progress_bar.set_postfix(loss=f"{loss.item():.4f}") + + model_logger.on_step_end(accelerator, model, save_steps, loss=loss) + scheduler.step() + if save_steps is None: + model_logger.on_epoch_end(accelerator, model, epoch_id) + model_logger.on_training_end(accelerator, model, save_steps) + + +def launch_data_process_task( + accelerator: Accelerator, + dataset: torch.utils.data.Dataset, + model: DiffusionTrainingModule, + model_logger: ModelLogger, + num_workers: int = 8, + args = None, +): + if args is not None: + num_workers = args.dataset_num_workers + + dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers) + model, dataloader = accelerator.prepare(model, dataloader) + + for data_id, data in enumerate(tqdm(dataloader)): + with accelerator.accumulate(model): + with torch.no_grad(): + folder = 
os.path.join(model_logger.output_path, str(accelerator.process_index)) + os.makedirs(folder, exist_ok=True) + save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth") + data = model(data) + torch.save(data, save_path) diff --git a/DiffSynth-Studio/diffsynth/diffusion/training_module.py b/DiffSynth-Studio/diffsynth/diffusion/training_module.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed8270b3882e8ab1696edfaf45c024f70c3d868 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/diffusion/training_module.py @@ -0,0 +1,316 @@ +import torch, json, os +from ..core import ModelConfig, load_state_dict +from ..utils.controlnet import ControlNetInput +from peft import LoraConfig, inject_adapter_in_model + + +class DiffusionTrainingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + + def to(self, *args, **kwargs): + for name, model in self.named_children(): + model.to(*args, **kwargs) + return self + + + def trainable_modules(self): + trainable_modules = filter(lambda p: p.requires_grad, self.parameters()) + return trainable_modules + + + def trainable_param_names(self): + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + return trainable_param_names + + + def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None): + if lora_alpha is None: + lora_alpha = lora_rank + if isinstance(target_modules, list) and len(target_modules) == 1: + target_modules = target_modules[0] + lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules) + model = inject_adapter_in_model(lora_config, model) + if upcast_dtype is not None: + for param in model.parameters(): + if param.requires_grad: + param.data = param.to(upcast_dtype) + return model + + + def mapping_lora_state_dict(self, state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if "lora_A.weight" in key or "lora_B.weight" in key: + new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight") + new_state_dict[new_key] = value + elif "lora_A.default.weight" in key or "lora_B.default.weight" in key: + new_state_dict[key] = value + return new_state_dict + + + def export_trainable_state_dict(self, state_dict, remove_prefix=None): + trainable_param_names = self.trainable_param_names() + state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names} + if remove_prefix is not None: + state_dict_ = {} + for name, param in state_dict.items(): + if name.startswith(remove_prefix): + name = name[len(remove_prefix):] + state_dict_[name] = param + state_dict = state_dict_ + return state_dict + + + def transfer_data_to_device(self, data, device, torch_float_dtype=None): + if data is None: + return data + elif isinstance(data, torch.Tensor): + data = data.to(device) + if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]: + data = data.to(torch_float_dtype) + return data + elif isinstance(data, tuple): + data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data) + return data + elif isinstance(data, list): + data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data) + return data + elif isinstance(data, dict): + data = {i: self.transfer_data_to_device(data[i], device, 
torch_float_dtype) for i in data} + return data + else: + return data + + def parse_vram_config(self, fp8=False, offload=False, device="cpu"): + if fp8: + return { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": device, + "onload_dtype": torch.float8_e4m3fn, + "onload_device": device, + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": device, + "computation_dtype": torch.bfloat16, + "computation_device": device, + } + elif offload: + return { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": device, + "computation_dtype": torch.bfloat16, + "computation_device": device, + "clear_parameters": True, + } + else: + return {} + + def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"): + fp8_models = [] if fp8_models is None else fp8_models.split(",") + offload_models = [] if offload_models is None else offload_models.split(",") + model_configs = [] + if model_paths is not None: + model_paths = json.loads(model_paths) + for path in model_paths: + vram_config = self.parse_vram_config( + fp8=path in fp8_models, + offload=path in offload_models, + device=device + ) + model_configs.append(ModelConfig(path=path, **vram_config)) + if model_id_with_origin_paths is not None: + model_id_with_origin_paths = model_id_with_origin_paths.split(",") + for model_id_with_origin_path in model_id_with_origin_paths: + vram_config = self.parse_vram_config( + fp8=model_id_with_origin_path in fp8_models, + offload=model_id_with_origin_path in offload_models, + device=device + ) + config = self.parse_path_or_model_id(model_id_with_origin_path) + model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config)) + return model_configs + + + def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None): + if model_id_with_origin_path is None: + return default_value + elif os.path.exists(model_id_with_origin_path): + return ModelConfig(path=model_id_with_origin_path) + else: + if ":" not in model_id_with_origin_path: + raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. 
This is neither a valid path nor in the format of `model_id:origin_file_pattern`.")
+ split_id = model_id_with_origin_path.rfind(":")
+ model_id = model_id_with_origin_path[:split_id]
+ origin_file_pattern = model_id_with_origin_path[split_id + 1:]
+ return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
+
+
+ def auto_detect_lora_target_modules(
+ self,
+ model: torch.nn.Module,
+ search_for_linear=False,
+ linear_detector=lambda x: min(x.weight.shape) >= 512,
+ block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
+ name_prefix="",
+ ):
+ lora_target_modules = []
+ if search_for_linear:
+ for name, module in model.named_modules():
+ module_name = name_prefix + ["", "."][name_prefix != ""] + name
+ if isinstance(module, torch.nn.Linear) and linear_detector(module):
+ lora_target_modules.append(module_name)
+ else:
+ for name, module in model.named_children():
+ module_name = name_prefix + ["", "."][name_prefix != ""] + name
+ lora_target_modules += self.auto_detect_lora_target_modules(
+ module,
+ search_for_linear=block_list_detector(module),
+ linear_detector=linear_detector,
+ block_list_detector=block_list_detector,
+ name_prefix=module_name,
+ )
+ return lora_target_modules
+
+
+ def parse_lora_target_modules(self, model, lora_target_modules):
+ # If no target modules are specified, search for them automatically;
+ # otherwise treat the argument as a comma-separated list of layer names.
+ if lora_target_modules == "":
+ short_names = self.auto_detect_lora_target_modules(model)
+ else:
+ short_names = [x.strip() for x in lora_target_modules.split(",")]
+
+ matched = []
+ for name, module in model.named_modules():
+ # Core rule: exclude i2v_adapter modules from LoRA injection.
+ if "i2v_adapter" in name:
+ continue
+ last = name.split(".")[-1]
+ for s in short_names:
+ if "." in s:
+ if name.endswith(s):
+ matched.append(name)
+ break
+ else:
+ if last == s:
+ matched.append(name)
+ break
+
+ # Deduplicate while preserving a stable order.
+ seen = set()
+ final = []
+ for n in matched:
+ if n not in seen:
+ final.append(n)
+ seen.add(n)
+ return final
+
+
+ def switch_pipe_to_training_mode(
+ self,
+ pipe,
+ trainable_models=None,
+ lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
+ preset_lora_path=None, preset_lora_model=None,
+ task="sft",
+ ):
+ # Scheduler
+ pipe.scheduler.set_timesteps(1000, training=True)
+
+ # Freeze untrainable models
+ pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+
+ # Preset LoRA
+ if preset_lora_path is not None:
+ pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)
+
+ # FP8
+ # FP8 relies on a model-specific memory management scheme.
+ # It is delegated to the subclass.
+
+ # Add LoRA to the base models
+ if lora_base_model is not None and not task.endswith(":data_process"):
+ if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
+ print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. 
If this occurs during the data processing stage, it is normal.") + return + + target_modules = self.parse_lora_target_modules( + getattr(pipe, lora_base_model), + lora_target_modules + ) + + model = self.add_lora_to_model( + getattr(pipe, lora_base_model), + target_modules=target_modules, + lora_rank=lora_rank, + upcast_dtype=pipe.torch_dtype, + ) + + for name, p in self.pipe.named_parameters(): + if "i2v_adapter" in name: + if "norm" in name: + p.requires_grad = False + else: + p.requires_grad = True + + + if lora_checkpoint is not None: + state_dict = load_state_dict(lora_checkpoint) + state_dict = self.mapping_lora_state_dict(state_dict) + load_result = model.load_state_dict(state_dict, strict=False) + print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") + if len(load_result[1]) > 0: + print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") + setattr(pipe, lora_base_model, model) + + + def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None): + models_require_backward = [] + if trainable_models is not None: + models_require_backward += trainable_models.split(",") + if lora_base_model is not None: + models_require_backward += [lora_base_model] + if task.endswith(":data_process"): + _, pipe.units = pipe.split_pipeline_units(models_require_backward) + elif task.endswith(":train"): + pipe.units, _ = pipe.split_pipeline_units(models_require_backward) + return pipe + + def parse_extra_inputs(self, data, extra_inputs, inputs_shared): + controlnet_keys_map = ( + ("blockwise_controlnet_", "blockwise_controlnet_inputs",), + ("controlnet_", "controlnet_inputs"), + ) + controlnet_inputs = {} + for extra_input in extra_inputs: + for prefix, name in controlnet_keys_map: + if extra_input.startswith(prefix): + if name not in controlnet_inputs: + controlnet_inputs[name] = {} + controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input] + break + else: + inputs_shared[extra_input] = data[extra_input] + for name, params in controlnet_inputs.items(): + inputs_shared[name] = [ControlNetInput(**params)] + return inputs_shared diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/longcat_video_dit.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/longcat_video_dit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..709e6b95fac59fbc1ba46d1ec39618ac74ee6152 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/longcat_video_dit.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/model_loader.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/model_loader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5421ae306ddcbc1ae1078c61e638a34186a85e3d Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/model_loader.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_animate_adapter.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_animate_adapter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee8671ba32c1ea972e67c01d568c73acf807a5e6 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_animate_adapter.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_camera_controller.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_camera_controller.cpython-39.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..4a237028347f65b2b826514fd29a561f7d746f44 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_camera_controller.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_dit.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_dit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f29deb3461a4e144b7c24d3b121cdf9a9bd528b5 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_dit.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_dit_s2v.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_dit_s2v.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cbf23e2a496bb137981fef4e927803829a1c4b7 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_dit_s2v.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11fa5d6050367ff0d37a714d9ff283f50362cbb0 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_mot.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_mot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..447c0cf573a3907ae5c829f38e92affdbb299b80 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_mot.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d1dacd7f9306d9905f6a582c915183ebe5b150c Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82a62ed8c83990a5c88b368323e44e29de2a3d89 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_vace.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_vace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..407dd589558952dfd0525193f866556c37768fcc Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_vace.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_vae.cpython-39.pyc b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_vae.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adb17ebf2cdff7cfb66c30f0259de7945639cff3 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wan_video_vae.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/__pycache__/wav2vec.cpython-39.pyc 
b/DiffSynth-Studio/diffsynth/models/__pycache__/wav2vec.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4140aa2254df2d35cc312d1c44ceed357eaed9fe Binary files /dev/null and b/DiffSynth-Studio/diffsynth/models/__pycache__/wav2vec.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/models/dinov3_image_encoder.py b/DiffSynth-Studio/diffsynth/models/dinov3_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c394a0315306534f8209e74d6953d6061299bf23 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/dinov3_image_encoder.py @@ -0,0 +1,96 @@ +from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast +from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig +import torch + +from ..core.device.npu_compatible_device import get_device_type + + +class DINOv3ImageEncoder(DINOv3ViTModel): + def __init__(self): + config = DINOv3ViTConfig( + architectures = [ + "DINOv3ViTModel" + ], + attention_dropout = 0.0, + drop_path_rate = 0.0, + dtype = "float32", + hidden_act = "silu", + hidden_size = 4096, + image_size = 224, + initializer_range = 0.02, + intermediate_size = 8192, + key_bias = False, + layer_norm_eps = 1e-05, + layerscale_value = 1.0, + mlp_bias = True, + model_type = "dinov3_vit", + num_attention_heads = 32, + num_channels = 3, + num_hidden_layers = 40, + num_register_tokens = 4, + patch_size = 16, + pos_embed_jitter = None, + pos_embed_rescale = 2.0, + pos_embed_shift = None, + proj_bias = True, + query_bias = False, + rope_theta = 100.0, + transformers_version = "4.56.1", + use_gated_mlp = True, + value_bias = False + ) + super().__init__(config) + self.processor = DINOv3ViTImageProcessorFast( + crop_size = None, + data_format = "channels_first", + default_to_square = True, + device = None, + disable_grouping = None, + do_center_crop = None, + do_convert_rgb = None, + do_normalize = True, + do_rescale = True, + do_resize = True, + image_mean = [ + 0.485, + 0.456, + 0.406 + ], + image_processor_type = "DINOv3ViTImageProcessorFast", + image_std = [ + 0.229, + 0.224, + 0.225 + ], + input_data_format = None, + resample = 2, + rescale_factor = 0.00392156862745098, + return_tensors = None, + size = { + "height": 224, + "width": 224 + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()): + inputs = self.processor(images=image, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device) + bool_masked_pos = None + head_mask = None + + pixel_values = pixel_values.to(torch_dtype) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + hidden_states = layer_module( + hidden_states, + attention_mask=layer_head_mask, + position_embeddings=position_embeddings, + ) + + sequence_output = self.norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + + return pooled_output diff --git a/DiffSynth-Studio/diffsynth/models/flux2_dit.py b/DiffSynth-Studio/diffsynth/models/flux2_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..316cf0840bf4d0697ff50e446ec699717ade95c0 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux2_dit.py @@ -0,0 +1,1048 @@ +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch, math +import torch.nn as nn +import torch.nn.functional as 
F +import numpy as np +from ..core.attention import attention_forward +from ..core.gradient import gradient_checkpoint_forward + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = torch.nn.SiLU() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AdaLayerNormContinuous(nn.Module): + r""" + Adaptive normalization layer with a norm layer (layer_norm 
or rms_norm). + + Args: + embedding_dim (`int`): Embedding dimension to use during projection. + conditioning_embedding_dim (`int`): Dimension of the input condition. + elementwise_affine (`bool`, defaults to `True`): + Boolean flag to denote if affine transformation should be applied. + eps (`float`, defaults to 1e-5): Epsilon factor. + bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use. + norm_type (`str`, defaults to `"layer_norm"`): + Normalization layer to use. Values supported: "layer_norm", "rms_norm". + """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor + ) # [D/2] + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + is_npu = freqs.device.type == "npu" + if is_npu: + freqs = freqs.float() + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio, allegro + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, + sequence_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
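+
+ Example (illustrative sketch; shapes match the attention processors later in this file):
+ query: [B, S, H, D]; cos, sin: each [S, D], e.g. produced by Flux2PosEmbed
+ out = apply_rotary_emb(query, (cos, sin), use_real=True, sequence_dim=1)
+ out keeps the shape and dtype of query.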
+ """ + if use_real: + cos, sin = freqs_cis # [S, D] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + +def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class Flux2SwiGLU(nn.Module): + """ + Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection + layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters. 
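+
+ A minimal sketch of the gating (illustrative):
+ x1, x2 = x.chunk(2, dim=-1)   # each [..., d] for an input of width 2*d
+ out = silu(x1) * x2           # output width d, i.e. half the input width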
+ """ + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + x = self.gate_fn(x1) * x2 + return x + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: float = 3.0, + inner_dim: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + # Flux2SwiGLU will reduce the dimension by half + self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias) + self.act_fn = Flux2SwiGLU() + self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + x = self.linear_out(x) + return x + + +class Flux2AttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class Flux2Attention(torch.nn.Module): + _default_processor_cls = Flux2AttnProcessor + _available_processors = [Flux2AttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + 
processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class Flux2ParallelSelfAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2ParallelSelfAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Parallel in (QKV + MLP in) projection + hidden_states = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + ) + + # Handle the attention logic + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # Handle the feedforward (FF) logic + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + + # Concatenate and parallel output projection + hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1) + hidden_states = attn.to_out(hidden_states) + + return hidden_states + + +class Flux2ParallelSelfAttention(torch.nn.Module): + """ + Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. + + This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF) + input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B + paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block. 
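+
+ Projection widths implied by the layer definitions below (for orientation):
+ to_qkv_mlp_proj: query_dim -> 3 * inner_dim + mlp_mult_factor * mlp_hidden_dim  (fused QKV and SwiGLU MLP input)
+ to_out: inner_dim + mlp_hidden_dim -> out_dim  (fused attention output and MLP output)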
+ """ + + _default_processor_cls = Flux2ParallelSelfAttnProcessor + _available_processors = [Flux2ParallelSelfAttnProcessor] + # Does not support QKV fusion as the QKV projections are always fused + _supports_qkv_fusion = False + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + # Fused QKV projections + MLP input projection + self.to_qkv_mlp_proj = torch.nn.Linear( + self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias + ) + self.mlp_act_fn = Flux2SwiGLU() + + # QK Norm + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + + # Fused attention output projection + MLP output projection + self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this + # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442) + # for a visual depiction of this type of transformer block. 
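+ # With the Flux2DiT defaults defined later in this file (num_attention_heads=48, attention_head_dim=128,
+ # mlp_ratio=3.0, hence dim=6144 and mlp_hidden_dim=18432), the fused input projection maps
+ # 6144 -> 3*6144 + 2*18432 = 55296 features and the fused output projection maps 6144 + 18432 -> 6144.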
+ self.attn = Flux2ParallelSelfAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + processor=Flux2ParallelSelfAttnProcessor(), + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + split_hidden_states: bool = False, + text_seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already + # concatenated + if encoder_hidden_states is not None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + else: + return hidden_states + + +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + processor=Flux2AttnProcessor(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + # Modulation parameters shape: [1, 1, self.dim] + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + # Img stream + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + 
shift_msa + + # Conditioning txt stream + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa + + # Attention on concatenated img + txt stream + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output, context_attn_output = attention_outputs + + # Process attention outputs for the image stream (`hidden_states`). + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + # Process attention outputs for the text stream (`encoder_hidden_states`). + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + # Expected ids shape: [S, len(self.axes_dim)] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] + for i in range(len(self.axes_dim)): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): + super().__init__() + + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + else: + self.guidance_embedder = None + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D) + + if guidance is not None and self.guidance_embedder is not 
None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) + time_guidance_emb = timesteps_emb + guidance_emb + return time_guidance_emb + else: + return timesteps_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + + self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + # Return tuple of 3-tuples of modulation params shift/scale/gate + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2DiT(torch.nn.Module): + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: Optional[int] = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + guidance_embeds: bool = True, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # 1. Sinusoidal positional embedding for RoPE on image and text tokens + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope) + + # 2. Combined timestep + guidance embedding + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, + ) + + # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) + # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks + self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + # 4. Input projections + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + + # 5. Double Stream Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + # 6. Single Stream Transformer Blocks + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + # 7. 
Output layers + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + # 0. Handle input arguments + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + num_txt_tokens = encoder_hidden_states.shape[1] + + # 1. Calculate timestep embedding and modulation parameters + timestep = timestep.to(hidden_states.dtype) * 1000 + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # 3. Calculate RoPE embeddings from image and text tokens + # NOTE: the below logic means that we can't support batched inference with images of different resolutions or + # text prompts of differents lengths. Is this a use case we want to support? + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + # 4. Double Stream Transformer Blocks + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # Concatenate text and image streams for single-block inference + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 5. Single Stream Transformer Blocks + for index_block, block in enumerate(self.single_transformer_blocks): + hidden_states = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + # Remove text tokens from concatenated stream + hidden_states = hidden_states[:, num_txt_tokens:, ...] + + # 6. 
Output layers + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return output diff --git a/DiffSynth-Studio/diffsynth/models/flux2_text_encoder.py b/DiffSynth-Studio/diffsynth/models/flux2_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c68411f3160655ebefd49dfb6424b19373a301 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux2_text_encoder.py @@ -0,0 +1,58 @@ +from transformers import Mistral3ForConditionalGeneration, Mistral3Config + + +class Flux2TextEncoder(Mistral3ForConditionalGeneration): + def __init__(self): + config = Mistral3Config(**{ + "architectures": [ + "Mistral3ForConditionalGeneration" + ], + "dtype": "bfloat16", + "image_token_index": 10, + "model_type": "mistral3", + "multimodal_projector_bias": False, + "projector_hidden_act": "gelu", + "spatial_merge_size": 2, + "text_config": { + "attention_dropout": 0.0, + "dtype": "bfloat16", + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 32768, + "max_position_embeddings": 131072, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 40, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000000.0, + "sliding_window": None, + "use_cache": True, + "vocab_size": 131072 + }, + "transformers_version": "4.57.1", + "vision_config": { + "attention_dropout": 0.0, + "dtype": "bfloat16", + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 1024, + "image_size": 1540, + "initializer_range": 0.02, + "intermediate_size": 4096, + "model_type": "pixtral", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "rope_theta": 10000.0 + }, + "vision_feature_layer": -1 + }) + super().__init__(config) + + def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs): + return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs) + diff --git a/DiffSynth-Studio/diffsynth/models/flux2_vae.py b/DiffSynth-Studio/diffsynth/models/flux2_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..c7904b17618cc3c0811e42fde0f80ecd8f15f7ee --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux2_vae.py @@ -0,0 +1,2322 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
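+
+# NOTE: the building blocks below (ResnetBlock2D, Downsample2D, Upsample2D, ...) closely follow
+# the corresponding modules in the diffusers library, per the license header above.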
+import math +from typing import Dict, Optional, Tuple, Union, Callable + +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +import inspect + +ACT2CLS = { + "swish": nn.SiLU, + "silu": nn.SiLU, + "mish": nn.Mish, + "gelu": nn.GELU, + "relu": nn.ReLU, +} + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACT2CLS: + return ACT2CLS[act_fn]() + else: + raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}") + +class ResnetBlock2D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + groups_out (`int`, *optional*, default to None): + The number of groups to use for the second normalization layer. if set to None, same as `groups`. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. + By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a + stronger conditioning with scale and shift. + kernel (`torch.Tensor`, optional, default to None): FIR filter, see + [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. + use_in_shortcut (`bool`, *optional*, default to `True`): + If `True`, add a 1x1 nn.conv2d layer for skip-connection. + up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. + down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. + conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the + `conv_shortcut` output. + conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. + If None, same as `out_channels`. 
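+
+ Example (an illustrative sketch, not from the original docstring):
+ block = ResnetBlock2D(in_channels=128, out_channels=256, temb_channels=512)
+ x, temb = torch.randn(1, 128, 32, 32), torch.randn(1, 512)
+ y = block(x, temb)   # -> [1, 256, 32, 32]; a 1x1 conv_shortcut aligns the skip connection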
+ """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", # default, scale_shift, + kernel: Optional[torch.Tensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + ): + super().__init__() + if time_embedding_norm == "ada_group": + raise ValueError( + "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead", + ) + if time_embedding_norm == "spatial": + raise ValueError( + "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead", + ) + + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + self.skip_time_act = skip_time_act + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels) + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") + + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = 
"The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if self.time_embedding_norm == "default": + if temb is not None: + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + elif self.time_embedding_norm == "scale_shift": + if temb is None: + raise ValueError( + f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}" + ) + time_scale, time_shift = torch.chunk(temb, 2, dim=1) + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + time_scale) + time_shift + else: + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor.contiguous()) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + +class Downsample2D(nn.Module): + """A 2D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 2D layer. 
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + name: str = "conv", + kernel_size=3, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + if use_conv: + conv = nn.Conv2d( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states) + + return hidden_states + +class Upsample2D(nn.Module): + """A 2D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 2D layer. 
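+
+        Example (a minimal usage sketch added for illustration; assumes the default
+        nearest-neighbour interpolation with a scale factor of 2):
+
+            import torch
+            upsample = Upsample2D(channels=64, use_conv=True)
+            x = torch.randn(1, 64, 32, 32)
+            y = upsample(x)  # expected shape: (1, 64, 64, 64)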
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.interpolate = interpolate + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + conv = None + if use_conv_transpose: + if kernel_size is None: + kernel_size = 4 + conv = nn.ConvTranspose2d( + channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + assert hidden_states.shape[1] == self.channels + + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1 + # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if self.interpolate: + # upsample_nearest_nhwc also fails when the number of output elements is large + # https://github.com/pytorch/pytorch/issues/141831 + scale_factor = ( + 2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])]) + ) + if hidden_states.numel() * scale_factor > pow(2, 31): + hidden_states = hidden_states.contiguous() + + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # Cast back to original dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. 
Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + is_causal: bool = False, + ): + super().__init__() + + # To prevent circular import. 
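+        # NOTE: FP32LayerNorm, LpNorm and RMSNorm used for `qk_norm` below are assumed to be
+        # defined elsewhere in this file, since the import is kept disabled: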
+ # from .normalization import FP32LayerNorm, LpNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applies qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim_head * heads, eps=eps) + self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "l2": + self.norm_q = LpNorm(p=2, dim=-1, eps=eps) + self.norm_k = LpNorm(p=2, dim=-1, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'." 
+ ) + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "layer_norm": + self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # Wan applies qk norm across all heads + # Wan also doesn't apply a q norm + self.norm_added_q = None + self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_xla_flash_attention( + self, + use_xla_flash_attention: bool, + partition_spec: Optional[Tuple[Optional[str], ...]] = None, + is_flux=False, + ) -> None: + r""" + Set whether to use xla flash attention from `torch_xla` or not. + + Args: + use_xla_flash_attention (`bool`): + Whether to use pallas flash attention kernel from `torch_xla` or not. + partition_spec (`Tuple[]`, *optional*): + Specify the partition specification if using SPMD. Otherwise None. + """ + if use_xla_flash_attention: + if not is_torch_xla_available: + raise "torch_xla is not available" + elif is_torch_xla_version("<", "2.3"): + raise "flash attention pallas kernel is supported from torch_xla version 2.3" + elif is_spmd() and is_torch_xla_version("<", "2.4"): + raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4" + else: + if is_flux: + processor = XLAFluxFlashAttnProcessor2_0(partition_spec) + else: + processor = XLAFlashAttnProcessor2_0(partition_spec) + else: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + r""" + Set whether to use npu flash attention from `torch_npu` or not. + + """ + if use_npu_flash_attention: + processor = AttnProcessorNPU() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. 
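+
+        Example (a hypothetical sketch; requires the `xformers` package, a CUDA device,
+        and the XFormers processor classes referenced elsewhere in this file):
+
+            attn = Attention(query_dim=320, heads=8, dim_head=40)
+            attn.set_use_memory_efficient_attention_xformers(True)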
+ """ + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + ), + ) + is_ip_adapter = hasattr(self, "processor") and isinstance( + self.processor, + (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor), + ) + is_joint_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + JointAttnProcessor2_0, + XFormersJointAttnProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and is_custom_diffusion: + raise NotImplementedError( + f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + dtype = None + if attention_op is not None: + op_fw, op_bw = attention_op + dtype, *_ = op_fw.SUPPORTED_DTYPES + q = torch.randn((1, 2, 40), device="cuda", dtype=dtype) + _ = xformers.ops.memory_efficient_attention(q, q, q) + except Exception as e: + raise e + + if is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." 
+ ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + elif is_ip_adapter: + processor = IPAdapterXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) + elif is_joint_processor: + processor = XFormersJointAttnProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_ip_adapter: + processor = IPAdapterAttnProcessor2_0( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. 
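+
+        Example (a minimal sketch; assumes PyTorch 2.x so that `AttnProcessor2_0` is available):
+
+            attn = Attention(query_dim=320, heads=8, dim_head=40)
+            attn.set_processor(AttnProcessor2_0())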
+ """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. 
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. 
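+                # Append a zero block along the key dimension (this mirrors the
+                # F.pad branch below; see its TODO about the pad width).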
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave( + head_size, dim=0, output_size=attention_mask.shape[0] * head_size + ) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave( + head_size, dim=1, output_size=attention_mask.shape[1] * head_size + ) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @torch.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + # handle added projections for SD3 and others. 
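+        # When add_q_proj / add_k_proj / add_v_proj are present (added-KV / joint attention),
+        # their weights are concatenated into a single `to_added_qkv` projection,
+        # mirroring the `to_qkv` fusion above.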
+ if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = 
F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, + height, width)`. 
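+
+    Example (a minimal usage sketch added for illustration; `temb_channels=None` mirrors
+    how this block is instantiated by the `Encoder` and `Decoder` below):
+
+        import torch
+        mid_block = UNetMidBlock2D(in_channels=512, temb_channels=None, attention_head_dim=512)
+        x = torch.randn(1, 512, 8, 8)
+        y = mid_block(x)  # expected shape: (1, 512, 8, 8)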
+ + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) 
+ + return hidden_states + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock2D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + # attention_head_dim=output_channel, + # temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Encoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # down + for down_block in self.down_blocks: + sample = self._gradient_checkpointing_func(down_block, sample) + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. 
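+
+    Example (a minimal usage sketch added for illustration; the channel and latent sizes here
+    are illustrative and smaller than the defaults used by `Flux2VAE` below):
+
+        import torch
+        decoder = Decoder(
+            in_channels=4,
+            out_channels=3,
+            up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
+            block_out_channels=(64, 128),
+        )
+        z = torch.randn(1, 4, 16, 16)
+        image = decoder(z)  # expected shape: (1, 3, 32, 32)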
+ """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock2D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + # prev_output_channel=prev_output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + # attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + latent_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Flux2VAE(torch.nn.Module): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. 
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + up_block_types: Tuple[str, ...] = ( + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ), + block_out_channels: Tuple[int, ...] = ( + 128, + 256, + 512, + 512, + ), + layers_per_block: int = 2, + act_fn: str = "silu", + latent_channels: int = 32, + norm_num_groups: int = 32, + sample_size: int = 1024, # YiYi notes: not sure + force_upcast: bool = True, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + mid_block_add_attention: bool = True, + batch_norm_eps: float = 1e-4, + batch_norm_momentum: float = 0.1, + patch_size: Tuple[int, int] = (2, 2), + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.bn = nn.BatchNorm2d( + math.prod(patch_size) * latent_channels, + eps=batch_norm_eps, + momentum=batch_norm_momentum, + affine=False, + track_running_stats=True, + ) + + self.use_slicing = False + self.use_tiling = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self): + r""" + Returns: + `dict` of attention processors: A dictionary 
containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + def encode( + self, x: torch.Tensor, return_dict: bool = True + ): + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
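+
+            Shape sketch (assuming the default configuration of this class): a
+            `(1, 3, 1024, 1024)` image is first encoded into a `(1, 64, 128, 128)` tensor
+            (8x spatial downsampling, `2 * latent_channels` channels); `encode` then regroups
+            it into 2x2 patches, keeps the first 128 channels, and standardizes them with the
+            BatchNorm running statistics, yielding a `(1, 128, 64, 64)` latent.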
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + + h = rearrange(h, "B C (H P) (W Q) -> B (C P Q) H W", P=2, Q=2) + h = h[:, :128] + latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(h.device, h.dtype) + latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to( + h.device, h.dtype + ) + h = (h - latents_bn_mean) / latents_bn_std + return h + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return dec + + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ): + latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(z.device, z.dtype) + latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to( + z.device, z.dtype + ) + z = z * latents_bn_std + latents_bn_mean + z = rearrange(z, "B (C P Q) H W -> B C (H P) (W Q)", P=2, Q=2) + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return decoded + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True): + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + return moments + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True): + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ): + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
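+
+            Note: `encode` in this class returns the standardized latent tensor directly, so a
+            plain reconstruction round trip can also be written as
+            `images_out = vae.decode(vae.encode(images))`.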
+ """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return dec diff --git a/DiffSynth-Studio/diffsynth/models/flux_controlnet.py b/DiffSynth-Studio/diffsynth/models/flux_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb1138bb74f7b55c6e92ac312098e4168829f0a --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_controlnet.py @@ -0,0 +1,384 @@ +import torch +from einops import rearrange, repeat +from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm +# from .utils import hash_state_dict_keys, init_weights_on_device +from contextlib import contextmanager + +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() + +@contextmanager +def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False): + + old_register_parameter = torch.nn.Module.register_parameter + if include_buffers: + old_register_buffer = torch.nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + def register_empty_buffer(module, name, buffer, persistent=True): + old_register_buffer(module, name, buffer, persistent=persistent) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(device) + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + if include_buffers: + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ["empty", "zeros", "ones", "full"] + } + else: + tensor_constructors_to_patch = {} + + try: + torch.nn.Module.register_parameter = register_empty_parameter + if include_buffers: + torch.nn.Module.register_buffer = register_empty_buffer + for torch_function_name in tensor_constructors_to_patch.keys(): + setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter + if include_buffers: + torch.nn.Module.register_buffer = old_register_buffer + for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) + +class FluxControlNet(torch.nn.Module): + def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0): + super().__init__() + self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) + self.time_embedder = TimestepEmbeddings(256, 3072) + self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) + self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) + self.context_embedder = torch.nn.Linear(4096, 3072) + self.x_embedder = torch.nn.Linear(64, 3072) + + 
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)]) + self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)]) + + self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)]) + self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)]) + + self.mode_dict = mode_dict + self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None + self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072) + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def patchify(self, hidden_states): + hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + + def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states): + if len(res_stack) == 0: + return [torch.zeros_like(hidden_states)] * num_blocks + interval = (num_blocks + len(res_stack) - 1) // len(res_stack) + aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)] + return aligned_res_stack + + + def forward( + self, + hidden_states, + controlnet_conditioning, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + processor_id=None, + tiled=False, tile_size=128, tile_stride=64, + **kwargs + ): + if image_ids is None: + image_ids = self.prepare_image_ids(hidden_states) + + conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) + if self.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) + prompt_emb = self.context_embedder(prompt_emb) + if self.controlnet_mode_embedder is not None: # Different from FluxDiT + processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int) + processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device) + prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1) + text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + + hidden_states = self.patchify(hidden_states) + hidden_states = self.x_embedder(hidden_states) + controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT + + controlnet_res_stack = [] + for block, controlnet_block in zip(self.blocks, self.controlnet_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + 
controlnet_res_stack.append(controlnet_block(hidden_states)) + + controlnet_single_res_stack = [] + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:])) + + controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:]) + controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:]) + + return controlnet_res_stack, controlnet_single_res_stack + + + # @staticmethod + # def state_dict_converter(): + # return FluxControlNetStateDictConverter() + + def quantize(self): + def cast_to(weight, dtype=None, device=None, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) + + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def cast_weight(s, input=None, dtype=None, device=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if device is None: + device = input.device + weight = cast_to(s.weight, dtype, device) + return weight + + def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + bias = None + weight = cast_to(s.weight, dtype, device) + bias = cast_to(s.bias, bias_dtype, device) + return weight, bias + + class quantized_layer: + class QLinear(torch.nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight,bias= cast_bias_weight(self,input) + return torch.nn.functional.linear(input,weight,bias) + + class QRMSNorm(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self,hidden_states,**kwargs): + weight= cast_weight(self.module,hidden_states) + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) + hidden_states = hidden_states.to(input_dtype) * weight + return hidden_states + + class QEmbedding(torch.nn.Embedding): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight= cast_weight(self,input) + return torch.nn.functional.embedding( + input, weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + + def replace_layer(model): + for name, module in model.named_children(): + if isinstance(module,quantized_layer.QRMSNorm): + continue + if isinstance(module, torch.nn.Linear): + with init_weights_on_device(): + new_layer = quantized_layer.QLinear(module.in_features,module.out_features) + new_layer.weight = module.weight + if module.bias is not None: + new_layer.bias = module.bias + setattr(model, name, new_layer) + elif isinstance(module, RMSNorm): + if hasattr(module,"quantized"): + continue + module.quantized= True + new_layer = quantized_layer.QRMSNorm(module) + setattr(model, name, new_layer) + elif isinstance(module,torch.nn.Embedding): + 
rows, cols = module.weight.shape + new_layer = quantized_layer.QEmbedding( + num_embeddings=rows, + embedding_dim=cols, + _weight=module.weight, + # _freeze=module.freeze, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse) + setattr(model, name, new_layer) + else: + replace_layer(module) + + replace_layer(self) + + + +class FluxControlNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + hash_value = hash_state_dict_keys(state_dict) + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." 
in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + if hash_value == "78d18b9101345ff695f312e7e62538c0": + extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}} + elif hash_value == "b001c89139b5f053c715fe772362dd2a": + extra_kwargs = {"num_single_blocks": 0} + elif hash_value == "52357cb26250681367488a8954c271e8": + extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4} + elif hash_value == "0cfd1740758423a2a854d67c136d1e8c": + extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1} + elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16": + extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10} + elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52": + extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0} + else: + extra_kwargs = {} + return state_dict_, extra_kwargs + + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/DiffSynth-Studio/diffsynth/models/flux_dit.py b/DiffSynth-Studio/diffsynth/models/flux_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..51a6e7f049d15e764383487c0edbf7da839b5918 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_dit.py @@ -0,0 +1,395 @@ +import torch +from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm +from einops import rearrange + + +def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size, num_tokens = hidden_states.shape[0:2] + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + +class RoPEEmbedding(torch.nn.Module): + def __init__(self, dim, theta, axes_dim): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + + def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0, "The dimension must be even." 
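+        # Sketch of the construction below (with, e.g., dim=56 and theta=10000, as in the
+        # `RoPEEmbedding(3072, 10000, [16, 56, 56])` instances built elsewhere in this file):
+        # `omega` holds dim/2 inverse frequencies decaying from 1.0 down to roughly 1/theta,
+        # and for every (position, frequency) pair the stacked cos/sin values form a 2x2
+        # rotation matrix [[cos, -sin], [sin, cos]] that `apply_rope` later applies to
+        # consecutive (even, odd) channel pairs of the queries and keys.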
+ + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + + + def forward(self, ids): + n_axes = ids.shape[-1] + emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) + return emb.unsqueeze(1) + + + +class FluxJointAttention(torch.nn.Module): + def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.only_out_a = only_out_a + + self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3) + self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3) + + self.norm_q_a = RMSNorm(head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(head_dim, eps=1e-6) + self.norm_q_b = RMSNorm(head_dim, eps=1e-6) + self.norm_k_b = RMSNorm(head_dim, eps=1e-6) + + self.a_to_out = torch.nn.Linear(dim_a, dim_a) + if not only_out_a: + self.b_to_out = torch.nn.Linear(dim_b, dim_b) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + batch_size = hidden_states_a.shape[0] + + # Part A + qkv_a = self.a_to_qkv(hidden_states_a) + qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_a, k_a, v_a = qkv_a.chunk(3, dim=1) + q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) + + # Part B + qkv_b = self.b_to_qkv(hidden_states_b) + qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_b, k_b, v_b = qkv_b.chunk(3, dim=1) + q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b) + + q = torch.concat([q_b, q_a], dim=2) + k = torch.concat([k_b, k_a], dim=2) + v = torch.concat([v_b, v_a], dim=2) + + q, k = self.apply_rope(q, k, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] + if ipadapter_kwargs_list is not None: + hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list) + hidden_states_a = self.a_to_out(hidden_states_a) + if self.only_out_a: + return hidden_states_a + else: + hidden_states_b = self.b_to_out(hidden_states_b) + return hidden_states_a, hidden_states_b + + + +class FluxJointTransformerBlock(torch.nn.Module): + def __init__(self, dim, num_attention_heads): + super().__init__() + self.norm1_a = AdaLayerNorm(dim) + self.norm1_b = AdaLayerNorm(dim) + + self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads) + + self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_a = torch.nn.Sequential( + 
torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_b = torch.nn.Sequential( + torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) + norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) + + # Attention + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + + # Part A + hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a + norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a + hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a) + + # Part B + hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b + norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b + hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b) + + return hidden_states_a, hidden_states_b + + + +class FluxSingleAttention(torch.nn.Module): + def __init__(self, dim_a, dim_b, num_heads, head_dim): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3) + + self.norm_q_a = RMSNorm(head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(head_dim, eps=1e-6) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + + def forward(self, hidden_states, image_rotary_emb): + batch_size = hidden_states.shape[0] + + qkv_a = self.a_to_qkv(hidden_states) + qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_a, k_a, v = qkv_a.chunk(3, dim=1) + q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) + + q, k = self.apply_rope(q_a, k_a, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + return hidden_states + + + +class AdaLayerNormSingle(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(dim, 3 * dim, bias=True) + self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + + +class FluxSingleTransformerBlock(torch.nn.Module): + def __init__(self, dim, num_attention_heads): + super().__init__() + self.num_heads = num_attention_heads + self.head_dim = dim // num_attention_heads + self.dim = dim + + self.norm = AdaLayerNormSingle(dim) + self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4)) + self.norm_q_a = RMSNorm(self.head_dim, 
eps=1e-6) + self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6) + + self.proj_out = torch.nn.Linear(dim * 5, dim) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + + def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + batch_size = hidden_states.shape[0] + + qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q, k, v = qkv.chunk(3, dim=1) + q, k = self.norm_q_a(q), self.norm_k_a(k) + + q, k = self.apply_rope(q, k, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + if ipadapter_kwargs_list is not None: + hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list) + return hidden_states + + + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + residual = hidden_states_a + norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) + hidden_states_a = self.to_qkv_mlp(norm_hidden_states) + attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] + + attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") + + hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a) + hidden_states_a = residual + hidden_states_a + + return hidden_states_a, hidden_states_b + + + +class AdaLayerNormContinuous(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(dim, dim * 2, bias=True) + self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False) + + def forward(self, x, conditioning): + emb = self.linear(self.silu(conditioning)) + shift, scale = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None] + shift[:, None] + return x + + + +class FluxDiT(torch.nn.Module): + def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19): + super().__init__() + self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) + self.time_embedder = TimestepEmbeddings(256, 3072) + self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) + self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) + self.context_embedder = torch.nn.Linear(4096, 3072) + self.x_embedder = torch.nn.Linear(input_dim, 3072) + + self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)]) + self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)]) + + self.final_norm_out = AdaLayerNormContinuous(3072) + self.final_proj_out = torch.nn.Linear(3072, 64) + + self.input_dim = input_dim + + + def patchify(self, hidden_states): + hidden_states = 
rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + + def unpatchify(self, hidden_states, height, width): + hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) + return hidden_states + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + total_seq_len = N * prompt_seq_len + image_seq_len + patched_masks = [self.patchify(entity_masks[i]) for i in range(N)] + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + image_start = N * prompt_seq_len + image_end = N * prompt_seq_len + image_seq_len + # prompt-image mask + for i in range(N): + prompt_start = i * prompt_seq_len + prompt_end = (i + 1) * prompt_seq_len + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt mask + for i in range(N): + for j in range(N): + if i != j: + prompt_start_i = i * prompt_seq_len + prompt_end_i = (i + 1) * prompt_seq_len + prompt_start_j = j * prompt_seq_len + prompt_end_j = (j + 1) * prompt_seq_len + attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + return attention_mask + + + def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim): + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + 
prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + return prompt_emb, image_rotary_emb, attention_mask + + + def forward( + self, + hidden_states, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None, + use_gradient_checkpointing=False, + **kwargs + ): + # (Deprecated) The real forward is in `pipelines.flux_image`. + return None diff --git a/DiffSynth-Studio/diffsynth/models/flux_infiniteyou.py b/DiffSynth-Studio/diffsynth/models/flux_infiniteyou.py new file mode 100644 index 0000000000000000000000000000000000000000..861538a4b02fb6a52edee662b6efcd60f78f6916 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_infiniteyou.py @@ -0,0 +1,129 @@ +import math +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class InfiniteYouImageProjector(nn.Module): + + def __init__( + self, + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=8, + embedding_dim=512, + output_dim=4096, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, 
dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + latents = latents.to(dtype=x.dtype, device=x.device) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + @staticmethod + def state_dict_converter(): + return FluxInfiniteYouImageProjectorStateDictConverter() + + +class FluxInfiniteYouImageProjectorStateDictConverter: + + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict['image_proj'] diff --git a/DiffSynth-Studio/diffsynth/models/flux_ipadapter.py b/DiffSynth-Studio/diffsynth/models/flux_ipadapter.py new file mode 100644 index 0000000000000000000000000000000000000000..31176fc2c2a508388502b45dc27e4d2218f16eec --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_ipadapter.py @@ -0,0 +1,110 @@ +from .general_modules import RMSNorm +from transformers import SiglipVisionModel, SiglipVisionConfig +import torch + + +class SiglipVisionModelSO400M(SiglipVisionModel): + def __init__(self): + config = SiglipVisionConfig( + hidden_size=1152, + image_size=384, + intermediate_size=4304, + model_type="siglip_vision_model", + num_attention_heads=16, + num_hidden_layers=27, + patch_size=14, + architectures=["SiglipModel"], + initializer_factor=1.0, + torch_dtype="float32", + transformers_version="4.37.0.dev0" + ) + super().__init__(config) + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + +class IpAdapterModule(torch.nn.Module): + def __init__(self, num_attention_heads, attention_head_dim, input_dim): + super().__init__() + self.num_heads = num_attention_heads + self.head_dim = attention_head_dim + output_dim = num_attention_heads * attention_head_dim + self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False) + + + def forward(self, hidden_states): + batch_size = hidden_states.shape[0] + # ip_k + ip_k = self.to_k_ip(hidden_states) + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_k = self.norm_added_k(ip_k) + # ip_v + ip_v = self.to_v_ip(hidden_states) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + return ip_k, ip_v + + +class FluxIpAdapter(torch.nn.Module): + def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57): + super().__init__() + self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)]) + self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens) + self.set_adapter() + + def 
set_adapter(self): + self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))} + + def forward(self, hidden_states, scale=1.0): + hidden_states = self.image_proj(hidden_states) + hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) + ip_kv_dict = {} + for block_id in self.call_block_id: + ipadapter_id = self.call_block_id[block_id] + ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) + ip_kv_dict[block_id] = { + "ip_k": ip_k, + "ip_v": ip_v, + "scale": scale + } + return ip_kv_dict + + @staticmethod + def state_dict_converter(): + return FluxIpAdapterStateDictConverter() + + +class FluxIpAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {} + for name in state_dict["ip_adapter"]: + name_ = 'ipadapter_modules.' + name + state_dict_[name_] = state_dict["ip_adapter"][name] + for name in state_dict["image_proj"]: + name_ = "image_proj." + name + state_dict_[name_] = state_dict["image_proj"][name] + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/DiffSynth-Studio/diffsynth/models/flux_lora_encoder.py b/DiffSynth-Studio/diffsynth/models/flux_lora_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..13589b0611f3140479ef4faa3b7a29371caa447b --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_lora_encoder.py @@ -0,0 +1,521 @@ +import torch +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if 
ipadapter_kwargs is not None: + hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, 
length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": 
"encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ + + + +class LoRALayerBlock(torch.nn.Module): + def __init__(self, L, dim_in, dim_out): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(1, L, dim_in)) + self.layer_norm = torch.nn.LayerNorm(dim_out) + + def forward(self, lora_A, lora_B): + x = self.x @ lora_A.T @ lora_B.T + x = self.layer_norm(x) + return x + + +class LoRAEmbedder(torch.nn.Module): + def __init__(self, lora_patterns=None, L=1, out_dim=2048): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1]) + self.model_dict = torch.nn.ModuleDict(model_dict) + + proj_dict = {} + for lora_pattern in lora_patterns: + 
layer_type, dim = lora_pattern["type"], lora_pattern["dim"] + if layer_type not in proj_dict: + proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim) + self.proj_dict = torch.nn.ModuleDict(proj_dict) + + self.lora_patterns = lora_patterns + + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432), + "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432), + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + return lora_patterns + + def forward(self, lora): + lora_emb = [] + for lora_pattern in self.lora_patterns: + name, layer_type = lora_pattern["name"], lora_pattern["type"] + lora_A = lora[name + ".lora_A.weight"] + lora_B = lora[name + ".lora_B.weight"] + lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) + lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) + lora_emb.append(lora_out) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + +class FluxLoRAEncoder(torch.nn.Module): + def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1): + super().__init__() + self.num_embeds_per_lora = num_embeds_per_lora + # embedder + self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)]) + + # special embedding + self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim)) + self.num_special_embeds = num_special_embeds + + # final layer + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + self.final_linear = torch.nn.Linear(embed_dim, embed_dim) + + def forward(self, lora): + lora_embeds = self.embedder(lora) + special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device) + embeds = torch.concat([special_embeds, lora_embeds], dim=1) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds) + embeds = embeds[:, :self.num_special_embeds] + embeds = self.final_layer_norm(embeds) + embeds = self.final_linear(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return FluxLoRAEncoderStateDictConverter() + + +class FluxLoRAEncoderStateDictConverter: + def from_civitai(self, state_dict): + return state_dict diff --git a/DiffSynth-Studio/diffsynth/models/flux_lora_patcher.py b/DiffSynth-Studio/diffsynth/models/flux_lora_patcher.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8fc8cea03bbbc658d8e1869432d903b6ae7ce9 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_lora_patcher.py @@ -0,0 +1,306 @@ +import torch, math +from ..core.loader import load_state_dict +from typing import Union + +class GeneralLoRALoader: + def __init__(self, device="cpu", torch_dtype=torch.float32): + self.device = device + self.torch_dtype = 
torch_dtype + + + def get_name_dict(self, lora_state_dict): + lora_name_dict = {} + for key in lora_state_dict: + if ".lora_B." not in key: + continue + keys = key.split(".") + if len(keys) > keys.index("lora_B") + 2: + keys.pop(keys.index("lora_B") + 1) + keys.pop(keys.index("lora_B")) + if keys[0] == "diffusion_model": + keys.pop(0) + keys.pop(-1) + target_name = ".".join(keys) + lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A.")) + return lora_name_dict + + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + updated_num = 0 + lora_name_dict = self.get_name_dict(state_dict_lora) + for name, module in model.named_modules(): + if name in lora_name_dict: + weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype) + weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + weight_lora = alpha * torch.mm(weight_up, weight_down) + state_dict = module.state_dict() + state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora + module.load_state_dict(state_dict) + updated_num += 1 + print(f"{updated_num} tensors are updated by LoRA.") + +class FluxLoRALoader(GeneralLoRALoader): + def __init__(self, device="cpu", torch_dtype=torch.float32): + super().__init__(device=device, torch_dtype=torch_dtype) + + self.diffusers_rename_dict = { + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight", + 
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight", + } + + self.civitai_rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": 
"blocks.blockid.norm1_b.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight", + } + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + super().load(model, state_dict_lora, alpha) + + + def convert_state_dict(self,state_dict): + + def guess_block_id(name,model_resource): + if model_resource == 'civitai': + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + if model_resource == 'diffusers': + names = name.split(".") + for i in names: + if i.isdigit(): + return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.") + return None, None + + def guess_resource(state_dict): + for k in state_dict: + if "lora_unet_" in k: + return 'civitai' + elif k.startswith("transformer."): + return 'diffusers' + else: + None + + model_resource = guess_resource(state_dict) + if model_resource is 
None: + return state_dict + + rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict + def guess_alpha(state_dict): + for name, param in state_dict.items(): + if ".alpha" in name: + for suffix in [".lora_down.weight", ".lora_A.weight"]: + name_ = name.replace(".alpha", suffix) + if name_ in state_dict: + lora_alpha = param.item() / state_dict[name_].shape[0] + lora_alpha = math.sqrt(lora_alpha) + return lora_alpha + + return 1 + + alpha = guess_alpha(state_dict) + + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name,model_resource) + if alpha != 1: + param *= alpha + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + + if model_resource == 'diffusers': + for name in list(state_dict_.keys()): + if "single_blocks." in name and ".a_to_q." in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + dim = 4 + if 'lora_A' in name: + dim = 1 + mlp = torch.zeros(dim * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + if 'lora_A' in name: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + elif 'lora_B' in name: + d, r = state_dict_[name].shape + param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device) + param[:d, :r] = state_dict_.pop(name) + param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")) + param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")) + param[3*d:, 3*r:] = mlp + else: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") + state_dict_[name_] = param + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." 
in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + concat_dim = 0 + if 'lora_A' in name: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + elif 'lora_B' in name: + origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + d, r = origin.shape + # print(d, r) + param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device) + param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")] + param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")] + else: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + return state_dict_ + + +class LoraMerger(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.weight_base = torch.nn.Parameter(torch.randn((dim,))) + self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) + self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) + self.weight_out = torch.nn.Parameter(torch.ones((dim,))) + self.bias = torch.nn.Parameter(torch.randn((dim,))) + self.activation = torch.nn.Sigmoid() + self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) + self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) + + def forward(self, base_output, lora_outputs): + norm_base_output = self.norm_base(base_output) + norm_lora_outputs = self.norm_lora(lora_outputs) + gate = self.activation( + norm_base_output * self.weight_base \ + + norm_lora_outputs * self.weight_lora \ + + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias + ) + output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) + return output + +class FluxLoraPatcher(torch.nn.Module): + def __init__(self, lora_patterns=None): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoraMerger(dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, + "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + return lora_patterns + + def forward(self, base_output, lora_outputs, name): + return 
self.model_dict[name.replace(".", "___")](base_output, lora_outputs) diff --git a/DiffSynth-Studio/diffsynth/models/flux_text_encoder_clip.py b/DiffSynth-Studio/diffsynth/models/flux_text_encoder_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..1425423ce6d1df946198a16a7e96078ab8fed807 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_text_encoder_clip.py @@ -0,0 +1,112 @@ +import torch + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class FluxTextEncoderClip(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + 
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=2, extra_mask=None): + embeds = self.token_embedding(input_ids) + embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device) + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + if extra_mask is not None: + attn_mask[:, extra_mask[0]==0] = float("-inf") + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + hidden_states = embeds + embeds = self.final_layer_norm(embeds) + pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)] + return pooled_embeds, hidden_states diff --git a/DiffSynth-Studio/diffsynth/models/flux_text_encoder_t5.py b/DiffSynth-Studio/diffsynth/models/flux_text_encoder_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..ee72e4a89b2089b62c6ea86c4aef91755ee9ee9a --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_text_encoder_t5.py @@ -0,0 +1,43 @@ +import torch +from transformers import T5EncoderModel, T5Config + + +class FluxTextEncoderT5(T5EncoderModel): + def __init__(self): + config = T5Config(**{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "dtype": "bfloat16", + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": True, + "is_gated_act": True, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": True, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": False, + "transformers_version": "4.57.1", + "use_cache": True, + "vocab_size": 32128 + }) + super().__init__(config) + + def forward(self, input_ids): + outputs = super().forward(input_ids=input_ids) + prompt_emb = outputs.last_hidden_state + return prompt_emb diff --git a/DiffSynth-Studio/diffsynth/models/flux_vae.py b/DiffSynth-Studio/diffsynth/models/flux_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..5eabeaee6ad54f0b1c02f1b24cd2ccd84d0238bf --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_vae.py @@ -0,0 +1,451 @@ +import torch +from einops import rearrange, repeat + + +class TileWorker: + def __init__(self): + pass + + + def mask(self, height, width, border_width): + # Create a mask with shape (height, width). + # The centre area is filled with 1, and the border line is filled with values in range (0, 1]. 
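+        # e.g. mask(height=4, width=4, border_width=2) returns
+        #   [[0.5, 0.5, 0.5, 0.5],
+        #    [0.5, 1.0, 1.0, 0.5],
+        #    [0.5, 1.0, 1.0, 0.5],
+        #    [0.5, 0.5, 0.5, 0.5]]
+        # i.e. each entry is min(distance to the nearest edge + 1) / border_width, clipped to [0, 1],
+        # so overlapping tiles are down-weighted near their borders and blended smoothly in untile().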
+ x = torch.arange(height).repeat(width, 1).T + y = torch.arange(width).repeat(height, 1) + mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values + mask = (mask / border_width).clip(0, 1) + return mask + + + def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype): + # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num) + batch_size, channel, _, _ = model_input.shape + model_input = model_input.to(device=tile_device, dtype=tile_dtype) + unfold_operator = torch.nn.Unfold( + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + model_input = unfold_operator(model_input) + model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1)) + + return model_input + + + def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype): + # Call y=forward_fn(x) for each tile + tile_num = model_input.shape[-1] + model_output_stack = [] + + for tile_id in range(0, tile_num, tile_batch_size): + + # process input + tile_id_ = min(tile_id + tile_batch_size, tile_num) + x = model_input[:, :, :, :, tile_id: tile_id_] + x = x.to(device=inference_device, dtype=inference_dtype) + x = rearrange(x, "b c h w n -> (n b) c h w") + + # process output + y = forward_fn(x) + y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id) + y = y.to(device=tile_device, dtype=tile_dtype) + model_output_stack.append(y) + + model_output = torch.concat(model_output_stack, dim=-1) + return model_output + + + def io_scale(self, model_output, tile_size): + # Determine the size modification happened in forward_fn + # We only consider the same scale on height and width. + io_scale = model_output.shape[2] / tile_size + return io_scale + + + def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype): + # The reversed function of tile + mask = self.mask(tile_size, tile_size, border_width) + mask = mask.to(device=tile_device, dtype=tile_dtype) + mask = rearrange(mask, "h w -> 1 1 h w 1") + model_output = model_output * mask + + fold_operator = torch.nn.Fold( + output_size=(height, width), + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1]) + model_output = rearrange(model_output, "b c h w n -> b (c h w) n") + model_output = fold_operator(model_output) / fold_operator(mask) + + return model_output + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + inference_device, inference_dtype = model_input.device, model_input.dtype + height, width = model_input.shape[2], model_input.shape[3] + border_width = int(tile_stride*0.5) if border_width is None else border_width + + # tile + model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype) + + # inference + model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype) + + # resize + io_scale = self.io_scale(model_output, tile_size) + height, width = int(height*io_scale), int(width*io_scale) + tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale) + border_width = int(border_width*io_scale) + + # untile + model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype) + + # 
Done! + model_output = model_output.to(device=inference_device, dtype=inference_dtype) + return model_output + + +class ConvAttention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q) + self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv) + self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv) + self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + conv_input = rearrange(hidden_states, "B L C -> B C L 1") + q = self.to_q(conv_input) + q = rearrange(q[:, :, :, 0], "B C L -> B L C") + conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1") + k = self.to_k(conv_input) + v = self.to_v(conv_input) + k = rearrange(k[:, :, :, 0], "B C L -> B L C") + v = rearrange(v[:, :, :, 0], "B C L -> B L C") + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + conv_input = rearrange(hidden_states, "B L C -> B C L 1") + hidden_states = self.to_out(conv_input) + hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C") + + return hidden_states + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + +class VAEAttentionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, 
in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + + if use_conv_attention: + self.transformer_blocks = torch.nn.ModuleList([ + ConvAttention( + inner_dim, + num_attention_heads, + attention_head_dim, + bias_q=True, + bias_kv=True, + bias_out=True + ) + for d in range(num_layers) + ]) + else: + self.transformer_blocks = torch.nn.ModuleList([ + Attention( + inner_dim, + num_attention_heads, + attention_head_dim, + bias_q=True, + bias_kv=True, + bias_out=True + ) + for d in range(num_layers) + ]) + + def forward(self, hidden_states, time_emb, text_emb, res_stack): + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states) + + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + + return hidden_states, time_emb, text_emb, res_stack + + +class ResnetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5): + super().__init__() + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nonlinearity = torch.nn.SiLU() + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + x = hidden_states + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + if time_emb is not None: + emb = self.nonlinearity(time_emb) + emb = self.time_emb_proj(emb)[:, :, None, None] + x = x + emb + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.conv2(x) + if self.conv_shortcut is not None: + hidden_states = self.conv_shortcut(hidden_states) + hidden_states = hidden_states + x + return hidden_states, time_emb, text_emb, res_stack + + +class UpSampler(torch.nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class DownSampler(torch.nn.Module): + def __init__(self, channels, padding=1, extra_padding=False): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding) + self.extra_padding = extra_padding + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + if self.extra_padding: + hidden_states = 
torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class FluxVAEDecoder(torch.nn.Module): + def __init__(self, use_conv_attention=True): + super().__init__() + self.scaling_factor = 0.3611 + self.shift_factor = 0.1159 + self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x + + self.blocks = torch.nn.ModuleList([ + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention), + ResnetBlock(512, 512, eps=1e-6), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + UpSampler(256), + # UpDecoderBlock2D + ResnetBlock(256, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states + + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Decoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + hidden_states = sample / self.scaling_factor + self.shift_factor + hidden_states = self.conv_in(hidden_states) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. 
output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class FluxVAEEncoder(torch.nn.Module): + def __init__(self, use_conv_attention=True): + super().__init__() + self.scaling_factor = 0.3611 + self.shift_factor = 0.1159 + self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # DownEncoderBlock2D + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + DownSampler(128, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(128, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + DownSampler(256, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(256, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + DownSampler(512, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention), + ResnetBlock(512, 512, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states + + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Decoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + hidden_states = self.conv_in(sample) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. 
output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + hidden_states = hidden_states[:, :16] + hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor + + return hidden_states + + def encode_video(self, sample, batch_size=8): + B = sample.shape[0] + hidden_states = [] + + for i in range(0, sample.shape[2], batch_size): + + j = min(i + batch_size, sample.shape[2]) + sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W") + + hidden_states_batch = self(sample_batch) + hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B) + + hidden_states.append(hidden_states_batch) + + hidden_states = torch.concat(hidden_states, dim=2) + return hidden_states diff --git a/DiffSynth-Studio/diffsynth/models/flux_value_control.py b/DiffSynth-Studio/diffsynth/models/flux_value_control.py new file mode 100644 index 0000000000000000000000000000000000000000..549dbc93b41343a42266af11584e2e7d39a17cd6 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/flux_value_control.py @@ -0,0 +1,56 @@ +import torch +from .general_modules import TemporalTimesteps + + +class MultiValueEncoder(torch.nn.Module): + def __init__(self, encoders=()): + super().__init__() + if not isinstance(encoders, list): + encoders = [encoders] + self.encoders = torch.nn.ModuleList(encoders) + + def __call__(self, values, dtype): + emb = [] + for encoder, value in zip(self.encoders, values): + if value is not None: + value = value.unsqueeze(0) + emb.append(encoder(value, dtype)) + emb = torch.concat(emb, dim=0) + return emb + + +class SingleValueEncoder(torch.nn.Module): + def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None): + super().__init__() + self.prefer_len = prefer_len + self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device) + self.prefer_value_embedder = torch.nn.Sequential( + torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out) + ) + self.positional_embedding = torch.nn.Parameter( + torch.randn(self.prefer_len, dim_out) + ) + + def forward(self, value, dtype): + value = value * 1000 + emb = self.prefer_proj(value).to(dtype) + emb = self.prefer_value_embedder(emb).squeeze(0) + base_embeddings = emb.expand(self.prefer_len, -1) + positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device) + learned_embeddings = base_embeddings + positional_embedding + return learned_embeddings + + @staticmethod + def state_dict_converter(): + return SingleValueEncoderStateDictConverter() + + +class SingleValueEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + return state_dict diff --git a/DiffSynth-Studio/diffsynth/models/general_modules.py b/DiffSynth-Studio/diffsynth/models/general_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1e97ba6e80447ab9f10eed9424edbd6e3d147cb4 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/general_modules.py @@ -0,0 +1,146 @@ +import torch, math + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, + computation_device = None, + align_dtype_to_timestep = False, +): + assert 
len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + if align_dtype_to_timestep: + emb = emb.to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TemporalTimesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.computation_device = computation_device + self.scale = scale + self.align_dtype_to_timestep = align_dtype_to_timestep + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + computation_device=self.computation_device, + scale=self.scale, + align_dtype_to_timestep=self.align_dtype_to_timestep, + ) + return t_emb + + +class DiffusersCompatibleTimestepProj(torch.nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.linear_1 = torch.nn.Linear(dim_in, dim_out) + self.act = torch.nn.SiLU() + self.linear_2 = torch.nn.Linear(dim_out, dim_out) + + def forward(self, x): + x = self.linear_1(x) + x = self.act(x) + x = self.linear_2(x) + return x + + +class TimestepEmbeddings(torch.nn.Module): + def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False): + super().__init__() + self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep) + if diffusers_compatible_format: + self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out) + else: + self.timestep_embedder = torch.nn.Sequential( + torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out) + ) + self.use_additional_t_cond = use_additional_t_cond + if use_additional_t_cond: + self.addition_t_embedding = torch.nn.Embedding(2, dim_out) + + def forward(self, timestep, dtype, addition_t_cond=None): + time_emb = self.time_proj(timestep).to(dtype) + time_emb = self.timestep_embedder(time_emb) + if addition_t_cond is not None: + addition_t_emb = self.addition_t_embedding(addition_t_cond) + addition_t_emb = addition_t_emb.to(dtype=dtype) + time_emb = time_emb + addition_t_emb + return time_emb + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim, eps, elementwise_affine=True): + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = torch.nn.Parameter(torch.ones((dim,))) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = 
hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + hidden_states = hidden_states.to(input_dtype) + if self.weight is not None: + hidden_states = hidden_states * self.weight + return hidden_states + + +class AdaLayerNorm(torch.nn.Module): + def __init__(self, dim, single=False, dual=False): + super().__init__() + self.single = single + self.dual = dual + self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual]) + self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(torch.nn.functional.silu(emb)) + if self.single: + scale, shift = emb.unsqueeze(1).chunk(2, dim=2) + x = self.norm(x) * (1 + scale) + shift + return x + elif self.dual: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2) + norm_x = self.norm(x) + x = norm_x * (1 + scale_msa) + shift_msa + norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2 + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2 + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp diff --git a/DiffSynth-Studio/diffsynth/models/longcat_video_dit.py b/DiffSynth-Studio/diffsynth/models/longcat_video_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe1c21b9304c484c72cb17e66ff97e7c85bc16b --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/longcat_video_dit.py @@ -0,0 +1,902 @@ +from typing import List, Optional, Tuple + +import math +import torch +import torch.nn as nn +import torch.amp as amp + +import numpy as np +import torch.nn.functional as F +from einops import rearrange, repeat +from .wan_video_dit import flash_attention +from ..core.device.npu_compatible_device import get_device_type +from ..core.gradient import gradient_checkpoint_forward + + +class RMSNorm_FP32(torch.nn.Module): + def __init__(self, dim: int, eps: float): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + +class RotaryPositionalEmbedding(nn.Module): + + def __init__(self, + head_dim, + cp_split_hw=None + ): + """Rotary positional embedding for 3D + Reference : https://blog.eleuther.ai/rotary-embeddings/ + Paper: https://arxiv.org/pdf/2104.09864.pdf + Args: + dim: Dimension of embedding + base: Base value for exponential + """ + super().__init__() + self.head_dim = head_dim + assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.' + self.cp_split_hw = cp_split_hw + # We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels + self.base = 10000 + self.freqs_dict = {} + + def register_grid_size(self, grid_size): + if grid_size not in self.freqs_dict: + self.freqs_dict.update({ + grid_size: self.precompute_freqs_cis_3d(grid_size) + }) + + def precompute_freqs_cis_3d(self, grid_size): + num_frames, height, width = grid_size + dim_t = self.head_dim - 4 * (self.head_dim // 6) + dim_h = 2 * (self.head_dim // 6) + dim_w = 2 * (self.head_dim // 6) + freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t)) + freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h)) + freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w)) + grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32) + grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32) + grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32) + grid_t = torch.from_numpy(grid_t).float() + grid_h = torch.from_numpy(grid_h).float() + grid_w = torch.from_numpy(grid_w).float() + freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t) + freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h) + freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w) + freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + # (T H W D) + freqs = rearrange(freqs, "T H W D -> (T H W) D") + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # with torch.no_grad(): + # freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width) + # freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw) + # freqs = rearrange(freqs, "T H W D -> (T H W) D") + + return freqs + + def forward(self, q, k, grid_size): + """3D RoPE. + + Args: + query: [B, head, seq, head_dim] + key: [B, head, seq, head_dim] + Returns: + query and key with the same shape as input. 
+ """ + + if grid_size not in self.freqs_dict: + self.register_grid_size(grid_size) + + freqs_cis = self.freqs_dict[grid_size].to(q.device) + q_, k_ = q.float(), k.float() + freqs_cis = freqs_cis.float().to(q.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + q_ = (q_ * cos) + (rotate_half(q_) * sin) + k_ = (k_ * cos) + (rotate_half(k_) * sin) + + return q_.type_as(q), k_.type_as(k) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + enable_flashattn3: bool = False, + enable_flashattn2: bool = False, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params: dict = None, + cp_split_hw: Optional[List[int]] = None + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.enable_flashattn3 = enable_flashattn3 + self.enable_flashattn2 = enable_flashattn2 + self.enable_xformers = enable_xformers + self.enable_bsa = enable_bsa + self.bsa_params = bsa_params + self.cp_split_hw = cp_split_hw + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.proj = nn.Linear(dim, dim) + + self.rope_3d = RotaryPositionalEmbedding( + self.head_dim, + cp_split_hw=cp_split_hw + ) + + def _process_attn(self, q, k, v, shape): + q = rearrange(q, "B H S D -> B S (H D)") + k = rearrange(k, "B H S D -> B S (H D)") + v = rearrange(v, "B H S D -> B S (H D)") + x = flash_attention(q, k, v, num_heads=self.num_heads) + x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads) + return x + + def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor: + """ + """ + B, N, C = x.shape + qkv = self.qkv(x) + + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D] + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if return_kv: + k_cache, v_cache = k.clone(), v.clone() + + q, k = self.rope_3d(q, k, shape) + + # cond mode + if num_cond_latents is not None and num_cond_latents > 0: + num_cond_latents_thw = num_cond_latents * (N // shape[0]) + # process the condition tokens + q_cond = q[:, :, :num_cond_latents_thw].contiguous() + k_cond = k[:, :, :num_cond_latents_thw].contiguous() + v_cond = v[:, :, :num_cond_latents_thw].contiguous() + x_cond = self._process_attn(q_cond, k_cond, v_cond, shape) + # process the noise tokens + q_noise = q[:, :, num_cond_latents_thw:].contiguous() + x_noise = self._process_attn(q_noise, k, v, shape) + # merge x_cond and x_noise + x = torch.cat([x_cond, x_noise], dim=2).contiguous() + else: + x = self._process_attn(q, k, v, shape) + + x_output_shape = (B, N, C) + x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D] + x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C] + x = self.proj(x) + + if return_kv: + return x, (k_cache, v_cache) + else: + return x + + def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor: + """ + """ + B, N, C = x.shape + qkv = self.qkv(x) + + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D] + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + T, H, W = shape + k_cache, v_cache = 
kv_cache + assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B] + if k_cache.shape[0] == 1: + k_cache = k_cache.repeat(B, 1, 1, 1) + v_cache = v_cache.repeat(B, 1, 1, 1) + + if num_cond_latents is not None and num_cond_latents > 0: + k_full = torch.cat([k_cache, k], dim=2).contiguous() + v_full = torch.cat([v_cache, v], dim=2).contiguous() + q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous() + q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W)) + q = q_padding[:, :, -N:].contiguous() + + x = self._process_attn(q, k_full, v_full, shape) + + x_output_shape = (B, N, C) + x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D] + x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C] + x = self.proj(x) + + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads, + enable_flashattn3=False, + enable_flashattn2=False, + enable_xformers=False, + ): + super(MultiHeadCrossAttention, self).__init__() + assert dim % num_heads == 0, "d_model must be divisible by num_heads" + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = nn.Linear(dim, dim) + self.kv_linear = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + + self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + + self.enable_flashattn3 = enable_flashattn3 + self.enable_flashattn2 = enable_flashattn2 + self.enable_xformers = enable_xformers + + def _process_cross_attn(self, x, cond, kv_seqlen): + B, N, C = x.shape + assert C == self.dim and cond.shape[2] == self.dim + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + q, k = self.q_norm(q), self.k_norm(k) + + q = rearrange(q, "B S H D -> B S (H D)") + k = rearrange(k, "B S H D -> B S (H D)") + v = rearrange(v, "B S H D -> B S (H D)") + x = flash_attention(q, k, v, num_heads=self.num_heads) + + x = x.view(B, -1, C) + x = self.proj(x) + return x + + def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None): + """ + x: [B, N, C] + cond: [B, M, C] + """ + if num_cond_latents is None or num_cond_latents == 0: + return self._process_cross_attn(x, cond, kv_seqlen) + else: + B, N, C = x.shape + if num_cond_latents is not None and num_cond_latents > 0: + assert shape is not None, "SHOULD pass in the shape" + num_cond_latents_thw = num_cond_latents * (N // shape[0]) + x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C] + output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C] + output = torch.cat([ + torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device), + output_noise + ], dim=1).contiguous() + else: + raise NotImplementedError + + return output + + +class LayerNorm_FP32(nn.LayerNorm): + def __init__(self, dim, eps, elementwise_affine): + super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + out = F.layer_norm( + inputs.float(), + self.normalized_shape, + None if self.weight is None else self.weight.float(), + None if self.bias is None else self.bias.float() , + self.eps + ).to(origin_dtype) + return out + + +def modulate_fp32(norm_func, x, shift, scale): + # Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D) + # ensure the modulation params be fp32 + 
assert shift.dtype == torch.float32, scale.dtype == torch.float32 + dtype = x.dtype + x = norm_func(x.to(torch.float32)) + x = x * (scale + 1) + shift + x = x.to(dtype) + return x + + +class FinalLayer_FP32(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim): + super().__init__() + self.hidden_size = hidden_size + self.num_patch = num_patch + self.out_channels = out_channels + self.adaln_tembed_dim = adaln_tembed_dim + + self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True)) + + def forward(self, x, t, latent_shape): + # timestep shape: [B, T, C] + assert t.dtype == torch.float32 + B, N, C = x.shape + T, _, _ = latent_shape + + with amp.autocast(get_device_type(), dtype=torch.float32): + shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C] + x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C) + x = self.linear(x) + return x + + +class FeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.dim = dim + self.hidden_dim = hidden_dim + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, t_embed_dim, frequency_embedding_size=256): + super().__init__() + self.t_embed_dim = t_embed_dim + self.frequency_embedding_size = frequency_embedding_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, t_embed_dim, bias=True), + nn.SiLU(), + nn.Linear(t_embed_dim, t_embed_dim, bias=True), + ) + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) + freqs = freqs.to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if t_freq.dtype != dtype: + t_freq = t_freq.to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. 
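+    Here the "labels" are caption features: a two-layer MLP
+    (Linear -> tanh-approximate GELU -> Linear) projects them from
+    in_channels to hidden_size.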
+ """ + + def __init__(self, in_channels, hidden_size): + super().__init__() + self.in_channels = in_channels + self.hidden_size = hidden_size + self.y_proj = nn.Sequential( + nn.Linear(in_channels, hidden_size, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + + def forward(self, caption): + B, _, N, C = caption.shape + caption = self.y_proj(caption) + return caption + + +class PatchEmbed3D(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + B, C, T, H, W = x.shape + x = self.proj(x) # (B C T H W) + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC + return x + + +class LongCatSingleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: int, + adaln_tembed_dim: int, + enable_flashattn3: bool = False, + enable_flashattn2: bool = False, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params=None, + cp_split_hw=None + ): + super().__init__() + + self.hidden_size = hidden_size + + # scale and gate modulation + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True) + ) + + self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False) + self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False) + self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True) + + self.attn = Attention( + dim=hidden_size, + num_heads=num_heads, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + enable_bsa=enable_bsa, + bsa_params=bsa_params, + cp_split_hw=cp_split_hw + ) + self.cross_attn = MultiHeadCrossAttention( + dim=hidden_size, + num_heads=num_heads, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + ) + self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio)) + + def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False): + """ + x: [B, N, C] + y: [1, N_valid_tokens, C] + t: [B, T, C_t] + y_seqlen: [B]; type of a list + latent_shape: latent shape of 
a single item + """ + x_dtype = x.dtype + + B, N, C = x.shape + T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W. + + # compute modulation params in fp32 + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + shift_msa, scale_msa, gate_msa, \ + shift_mlp, scale_mlp, gate_mlp = \ + self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C] + + # self attn with modulation + x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C) + + if kv_cache is not None: + kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device)) + attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache) + else: + attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv) + + if return_kv: + x_s, kv_cache = attn_outputs + else: + x_s = attn_outputs + + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] + x = x.to(x_dtype) + + # cross attn + if not skip_crs_attn: + if kv_cache is not None: + num_cond_latents = None + x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape) + + # ffn with modulation + x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) + x_s = self.ffn(x_m) + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] + x = x.to(x_dtype) + + if return_kv: + return x, kv_cache + else: + return x + + +class LongCatVideoTransformer3DModel(torch.nn.Module): + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + hidden_size: int = 4096, + depth: int = 48, + num_heads: int = 32, + caption_channels: int = 4096, + mlp_ratio: int = 4, + adaln_tembed_dim: int = 512, + frequency_embedding_size: int = 256, + # default params + patch_size: Tuple[int] = (1, 2, 2), + # attention config + enable_flashattn3: bool = False, + enable_flashattn2: bool = True, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]}, + cp_split_hw: Optional[List[int]] = [1, 1], + text_tokens_zero_pad: bool = True, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.cp_split_hw = cp_split_hw + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + ) + + self.blocks = nn.ModuleList( + [ + LongCatSingleStreamBlock( + hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + adaln_tembed_dim=adaln_tembed_dim, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + enable_bsa=enable_bsa, + bsa_params=bsa_params, + cp_split_hw=cp_split_hw + ) + for i in range(depth) + ] + ) + + self.final_layer = FinalLayer_FP32( + hidden_size, + np.prod(self.patch_size), + out_channels, + adaln_tembed_dim, + ) + + self.gradient_checkpointing = False + self.text_tokens_zero_pad = text_tokens_zero_pad + + self.lora_dict = {} + self.active_loras = [] + + def 
enable_loras(self, lora_key_list=[]): + self.disable_all_loras() + + module_loras = {} # {module_name: [lora1, lora2, ...]} + model_device = next(self.parameters()).device + model_dtype = next(self.parameters()).dtype + + for lora_key in lora_key_list: + if lora_key in self.lora_dict: + for lora in self.lora_dict[lora_key].loras: + lora.to(model_device, dtype=model_dtype, non_blocking=True) + module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".") + if module_name not in module_loras: + module_loras[module_name] = [] + module_loras[module_name].append(lora) + self.active_loras.append(lora_key) + + for module_name, loras in module_loras.items(): + module = self._get_module_by_name(module_name) + if not hasattr(module, 'org_forward'): + module.org_forward = module.forward + module.forward = self._create_multi_lora_forward(module, loras) + + def _create_multi_lora_forward(self, module, loras): + def multi_lora_forward(x, *args, **kwargs): + weight_dtype = x.dtype + org_output = module.org_forward(x, *args, **kwargs) + + total_lora_output = 0 + for lora in loras: + if lora.use_lora: + lx = lora.lora_down(x.to(lora.lora_down.weight.dtype)) + lx = lora.lora_up(lx) + lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale + total_lora_output += lora_output + + return org_output + total_lora_output + + return multi_lora_forward + + def _get_module_by_name(self, module_name): + try: + module = self + for part in module_name.split('.'): + module = getattr(module, part) + return module + except AttributeError as e: + raise ValueError(f"Cannot find module: {module_name}, error: {e}") + + def disable_all_loras(self): + for name, module in self.named_modules(): + if hasattr(module, 'org_forward'): + module.forward = module.org_forward + delattr(module, 'org_forward') + + for lora_key, lora_network in self.lora_dict.items(): + for lora in lora_network.loras: + lora.to("cpu") + + self.active_loras.clear() + + def enable_bsa(self,): + for block in self.blocks: + block.attn.enable_bsa = True + + def disable_bsa(self,): + for block in self.blocks: + block.attn.enable_bsa = False + + def forward( + self, + hidden_states, + timestep, + encoder_hidden_states, + encoder_attention_mask=None, + num_cond_latents=0, + return_kv=False, + kv_cache_dict={}, + skip_crs_attn=False, + offload_kv_cache=False, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + + B, _, T, H, W = hidden_states.shape + + N_t = T // self.patch_size[0] + N_h = H // self.patch_size[1] + N_w = W // self.patch_size[2] + + assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension." 
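+        # With the default patch_size = (1, 2, 2), the latent token grid is
+        # (N_t, N_h, N_w) = (T, H // 2, W // 2), so the per-frame timestep
+        # embedding computed below maps one-to-one onto the N_t latent frames.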
+ + # expand the shape of timestep from [B] to [B, T] + if len(timestep.shape) == 1: + timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T] + timestep[:, :num_cond_latents] = 0 + + dtype = hidden_states.dtype + hidden_states = hidden_states.to(dtype) + timestep = timestep.to(dtype) + encoder_hidden_states = encoder_hidden_states.to(dtype) + + hidden_states = self.x_embedder(hidden_states) # [B, N, C] + + with amp.autocast(device_type=get_device_type(), dtype=torch.float32): + t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t] + + encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C] + + if self.text_tokens_zero_pad and encoder_attention_mask is not None: + encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None] + encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype) + + if encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1) + encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C] + y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B] + else: + y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0] + encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1]) + + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w) + # hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw) + # hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C") + + # blocks + kv_cache_dict_ret = {} + for i, block in enumerate(self.blocks): + block_outputs = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=hidden_states, + y=encoder_hidden_states, + t=t, + y_seqlen=y_seqlens, + latent_shape=(N_t, N_h, N_w), + num_cond_latents=num_cond_latents, + return_kv=return_kv, + kv_cache=kv_cache_dict.get(i, None), + skip_crs_attn=skip_crs_attn, + ) + + if return_kv: + hidden_states, kv_cache = block_outputs + if offload_kv_cache: + kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu()) + else: + kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous()) + else: + hidden_states = block_outputs + + hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out] + + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw) + + hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W] + + # cast to float32 for better accuracy + hidden_states = hidden_states.to(torch.float32) + + if return_kv: + return hidden_states, kv_cache_dict_ret + else: + return hidden_states + + + def unpatchify(self, x, N_t, N_h, N_w): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + 
return x + + @staticmethod + def state_dict_converter(): + return LongCatVideoTransformer3DModelDictConverter() + + +class LongCatVideoTransformer3DModelDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + return state_dict + diff --git a/DiffSynth-Studio/diffsynth/models/model_loader.py b/DiffSynth-Studio/diffsynth/models/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..16d72ddba8abb44dda24a03122d9fff79dd42660 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/model_loader.py @@ -0,0 +1,111 @@ +from ..core.loader import load_model, hash_model_file +from ..core.vram import AutoWrappedModule +from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS +import importlib, json, torch + + +class ModelPool: + def __init__(self): + self.model = [] + self.model_name = [] + self.model_path = [] + + def import_model_class(self, model_class): + split = model_class.rfind(".") + model_resource, model_class = model_class[:split], model_class[split+1:] + model_class = importlib.import_module(model_resource).__getattribute__(model_class) + return model_class + + def need_to_enable_vram_management(self, vram_config): + return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None + + def fetch_module_map(self, model_class, vram_config): + if self.need_to_enable_vram_management(vram_config): + if model_class in VRAM_MANAGEMENT_MODULE_MAPS: + module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()} + else: + module_map = {self.import_model_class(model_class): AutoWrappedModule} + else: + module_map = None + return module_map + + def load_model_file(self, config, path, vram_config, vram_limit=None): + model_class = self.import_model_class(config["model_class"]) + model_config = config.get("extra_kwargs", {}) + if "state_dict_converter" in config: + state_dict_converter = self.import_model_class(config["state_dict_converter"]) + else: + state_dict_converter = None + module_map = self.fetch_module_map(config["model_class"], vram_config) + model = load_model( + model_class, path, model_config, + vram_config["computation_dtype"], vram_config["computation_device"], + state_dict_converter, + use_disk_map=True, + vram_config=vram_config, module_map=module_map, vram_limit=vram_limit, + ) + return model + + def default_vram_config(self): + vram_config = { + "offload_dtype": None, + "offload_device": None, + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cpu", + "computation_dtype": torch.bfloat16, + "computation_device": "cpu", + } + return vram_config + + def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False): + print(f"Loading models from: {json.dumps(path, indent=4)}") + if vram_config is None: + vram_config = self.default_vram_config() + model_hash = hash_model_file(path) + loaded = False + for config in MODEL_CONFIGS: + if config["model_hash"] == model_hash: + model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit) + if clear_parameters: self.clear_parameters(model) + self.model.append(model) + model_name = config["model_name"] + self.model_name.append(model_name) + self.model_path.append(path) + model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")} + 
print(f"Loaded model: {json.dumps(model_info, indent=4)}") + loaded = True + if not loaded: + raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}") + + def fetch_model(self, model_name, index=None): + fetched_models = [] + fetched_model_paths = [] + for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name): + if model_name == model_name_: + fetched_models.append(model) + fetched_model_paths.append(model_path) + if len(fetched_models) == 0: + print(f"No {model_name} models available. This is not an error.") + model = None + elif len(fetched_models) == 1: + print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.") + model = fetched_models[0] + else: + if index is None: + model = fetched_models[0] + print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.") + elif isinstance(index, int): + model = fetched_models[:index] + print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.") + else: + model = fetched_models + print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.") + return model + + def clear_parameters(self, model: torch.nn.Module): + for name, module in model.named_children(): + self.clear_parameters(module) + for name, param in model.named_parameters(recurse=False): + setattr(model, name, None) diff --git a/DiffSynth-Studio/diffsynth/models/nexus_gen.py b/DiffSynth-Studio/diffsynth/models/nexus_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..011039842312b01a0bbb69b999bc868902736e9a --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/nexus_gen.py @@ -0,0 +1,161 @@ +import torch +from PIL import Image + + +class NexusGenAutoregressiveModel(torch.nn.Module): + def __init__(self, max_length=1024, max_pixels=262640): + super(NexusGenAutoregressiveModel, self).__init__() + from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration + from transformers import Qwen2_5_VLConfig + self.max_length = max_length + self.max_pixels = max_pixels + model_config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 
14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.model = Qwen2_5_VLForConditionalGeneration(model_config) + self.processor = None + + + def load_processor(self, path): + from .nexus_gen_ar_model import Qwen2_5_VLProcessor + self.processor = Qwen2_5_VLProcessor.from_pretrained(path) + + + @staticmethod + def state_dict_converter(): + return NexusGenAutoregressiveModelStateDictConverter() + + def bound_image(self, image, max_pixels=262640): + from qwen_vl_utils import smart_resize + resized_height, resized_width = smart_resize( + image.height, + image.width, + max_pixels=max_pixels, + ) + return image.resize((resized_width, resized_height)) + + def get_editing_msg(self, instruction): + if '' not in instruction: + instruction = ' ' + instruction + messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: "}] + return messages + + def get_generation_msg(self, instruction): + instruction = "Generate an image according to the following description: {}".format(instruction) + messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: "}] + return messages + + def forward(self, instruction, ref_image=None, num_img_tokens=81): + """ + Generate target embeddings for the given instruction and reference image. + """ + if ref_image is not None: + messages = self.get_editing_msg(instruction) + images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))] + output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens) + else: + messages = self.get_generation_msg(instruction) + images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))] + output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens) + + return output_image_embeddings + + def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81): + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + text = text.replace('', '<|vision_start|><|image_pad|><|vision_end|>') + inputs = processor( + text=[text], + images=images, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to(model.device) + + input_embeds = model.model.embed_tokens(inputs['input_ids']) + image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw']) + ground_truth_image_embeds = image_embeds[-num_img_tokens:] + input_image_embeds = image_embeds[:-num_img_tokens] + + image_mask = inputs['input_ids'] == model.config.image_token_id + indices = image_mask.cumsum(dim=1) + input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask) + gt_image_mask = torch.logical_and(image_mask, ~input_image_mask) + input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds) + input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds) + + image_prefill_embeds = model.image_prefill_embeds( + torch.arange(81, device=model.device).long() + ) + input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds) + + position_ids, _ = model.get_rope_index( + inputs['input_ids'], + inputs['image_grid_thw'], + attention_mask=inputs['attention_mask']) + 
position_ids = position_ids.contiguous() + outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True) + output_image_embeddings = outputs.image_embeddings[:, :-1, :] + output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]] + return output_image_embeddings, input_image_embeds, inputs['image_grid_thw'] + + +class NexusGenAutoregressiveModelStateDictConverter: + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict = {"model." + key: value for key, value in state_dict.items()} + return state_dict diff --git a/DiffSynth-Studio/diffsynth/models/nexus_gen_ar_model.py b/DiffSynth-Studio/diffsynth/models/nexus_gen_ar_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b647786aafc5c80245272269d3e0a525e03b2da1 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/nexus_gen_ar_model.py @@ -0,0 +1,1143 @@ +import os +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin, LogitsProcessorList, StoppingCriteriaList, GenerationConfig, GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput +from transformers.utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.modeling_outputs import ModelOutput +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLModel, + Qwen2_5_VLPreTrainedModel, + QWEN2_5_VL_INPUTS_DOCSTRING, + ) + +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + image_embeddings: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vision_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + self.image_prefill_embeds = nn.Embedding(81, config.hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. 
+ Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = 
llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + image_embeddings: Optional[torch.Tensor] = None, + token_loss_weight: Optional[float] = 0.1, + img_loss_weight: Optional[float] = 1.0, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + # test feature + inputs_embeds = self.model.embed_tokens(input_ids) + # for image encoding and training + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + # position_ids [3, B, L] + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + image_embeds = self.vision_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + # prepare labels for logits + logits_labels = labels.clone().detach() + image_tokens = (labels == self.config.image_token_id) + logits_labels[image_tokens] = -100 + + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = logits_labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) * token_loss_weight + + shift_image_tokens_2d = (labels[..., 1:].contiguous() == self.config.image_token_id) # (B, L-1) + shifted_image_embeds = image_embeds[:, :-1, :].contiguous() # (B, L-1, D) + masked_image_embeds = shifted_image_embeds[shift_image_tokens_2d] # (num_image_tokens, D) + + mse_loss_fct = nn.MSELoss() + mse_loss_fct = mse_loss_fct.to(shift_logits.device) + if image_embeddings is None: + image_embeddings = torch.zeros_like(masked_image_embeds) + img_loss = mse_loss_fct(masked_image_embeds, image_embeddings) + + cos_sim = torch.cosine_similarity( + masked_image_embeds, + image_embeddings, + dim=-1 + ) + cos_loss = (1 - cos_sim).mean() + img_loss = 0.5 * img_loss + 0.5 * cos_loss + # fix nan for empty image tokens + if image_embeddings.size(0) == 0: + img_loss = img_loss.nan_to_num(0.0) + # combine the loss + loss = loss + img_loss_weight * img_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + image_embeddings=image_embeds, + past_key_values=outputs.past_key_values, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
+ """ + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + model_forward = self.__call__ + if isinstance(model_kwargs.get("past_key_values"), Cache): + is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache + is_compileable = is_compileable and not self.generation_config.disable_compile + if is_compileable and ( + self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices + ): + os.environ["TOKENIZERS_PARALLELISM"] = "0" + model_forward = self.get_compiled_call(generation_config.compile_config) + + is_prefill = True + is_sampling_img = input_ids[:, -1] == self.config.vision_start_token_id + generation_image_grid_thw = model_kwargs.pop("generation_image_grid_thw", self.get_default_image_grid_thw()) + num_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) + output_image_embeddings = [] + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare prefilled embeds + model_inputs.update(self.prepare_prefilled_image_embeds(len(output_image_embeddings), num_img_tokens, is_sampling_img, **model_kwargs)) + + # parse position_ids from model_kwargs + model_inputs.update(self.prepare_image_position_ids(input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs)) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + if is_prefill: + outputs = self(**model_inputs, return_dict=True) + is_prefill = False + else: + outputs = model_forward(**model_inputs, return_dict=True) 
+ + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + # TODO: support batch image sampling + if bool(is_sampling_img) and len(output_image_embeddings) < num_img_tokens: + output_image_embeddings.append(outputs.image_embeddings[:, -1, :].unsqueeze(1)) + + if synced_gpus and this_peer_finished: + continue + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone().float() + next_token_logits = next_token_logits.to(input_ids.device) + + # do not sample token + next_token_logits[:, self.config.vision_end_token_id] = -float('inf') + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + # while not bool(is_sampling_img) and torch.any(next_tokens == self.config.vision_end_token_id): + # probs[:, self.config.vision_end_token_id] = 0 + # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + #TODO: support batch image sample + if num_img_tokens is not None: + cur_img_tokens = (input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1) + # check whether is sampling images + is_end_img = torch.logical_and(cur_img_tokens == num_img_tokens, is_sampling_img) + is_sampling_img = torch.logical_and(is_sampling_img, cur_img_tokens < num_img_tokens) + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to end sampling images + next_tokens[is_end_img] = self.config.vision_end_token_id + else: + # check whether to end sampling images + is_sampling_img = torch.logical_and(is_sampling_img, (next_tokens != self.config.vision_end_token_id)) + # replace the next token with the image token if is sampling image + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to start sampling images + is_sampling_img = torch.logical_or(is_sampling_img, (next_tokens == self.config.vision_start_token_id)) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + if streamer is not None: + streamer.put(next_tokens.cpu()) + + unfinished_sequences = unfinished_sequences & 
~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + # output the image embeddings + output_image_embeddings = torch.cat(output_image_embeddings, dim=1) if len(output_image_embeddings) > 0 else None + + if return_dict_in_generate: + return GenerateDecoderOnlyAll2AllOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + output_image_embeddings=output_image_embeddings, + ) + else: + return input_ids + + + def prepare_prefilled_image_embeds(self, cur_image_tokens, num_img_tokens, is_sampling_img, **model_kwargs): + if cur_image_tokens == 0 or cur_image_tokens > num_img_tokens or not bool(is_sampling_img): + return {} + # TODO: support batch image sample + image_idx = torch.tensor([cur_image_tokens-1]).to(self.device).long().unsqueeze(0) + inputs_embeds = self.image_prefill_embeds(image_idx) + return {"inputs_embeds": inputs_embeds} + + + def get_default_image_grid_thw(self,): + return torch.tensor([[1, 18, 18]]).to(self.device) + + + def get_num_image_tokens(self, image_grid_thw): + return int(torch.prod(image_grid_thw, dim=1).sum() // 4) + + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + num_img_tokens = model_kwargs.pop("generation_image_grid_thw", None) + super()._validate_model_kwargs(model_kwargs) + model_kwargs["generation_image_grid_thw"] = num_img_tokens + + def prepare_image_position_ids(self, input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs): + # Overwritten -- prepare position_ids for image tokens + cur_img_tokens = int((input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1)) + # TODO: support batch image sample + if cur_img_tokens > 0 and bool(is_sampling_img): + image_grid_thw = generation_image_grid_thw + if model_kwargs.get('image_grid_thw') is not None: + image_grid_thw = torch.cat([model_kwargs.get('image_grid_thw'), image_grid_thw]) + remaining_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) - cur_img_tokens + padding_ids = input_ids.new_full((1, remaining_img_tokens), fill_value=self.config.image_token_id) + padded_ids = torch.cat([input_ids, padding_ids], dim=1) + position_ids, _ = self.get_rope_index(padded_ids, image_grid_thw, None, None) + if model_kwargs.get("use_cache", True): + position_ids = position_ids[:, :, input_ids.shape[1] - 1].unsqueeze(-1) + else: + position_ids = position_ids[:, :, :input_ids.shape[1]] + return {"position_ids": position_ids} + return {} + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + image_embeddings=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + 
position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + # Qwen2-5-VL position_ids are prepared with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = 
torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] + + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
+ """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. 
+ """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." + ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def batch_decode_all2all(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + decoded = self.tokenizer.batch_decode(*args, **kwargs) + pattern = r'<\|vision_start\|>.*?<\|vision_end\|>' + decoded_with_image_tag = [re.sub(pattern, '', d, flags=re.DOTALL) for d in decoded] + decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag] + return decoded_with_image_tag + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. 
+ + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return names_from_processor + ["second_per_grid_ts"] + + +__all__ = ["Qwen2_5_VLProcessor"] diff --git a/DiffSynth-Studio/diffsynth/models/nexus_gen_projector.py b/DiffSynth-Studio/diffsynth/models/nexus_gen_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..d69b3e1bfd50fc3b9c098f7775afae4020f9b320 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/nexus_gen_projector.py @@ -0,0 +1,417 @@ +import math +import torch +import torch.nn as nn +from typing import Optional, Tuple + + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + from transformers.modeling_rope_utils import _compute_default_rope_parameters + self.rope_init_fn = _compute_default_rope_parameters + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 
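+        # Dynamic-RoPE bookkeeping: grow `inv_freq` when the sequence exceeds the cached
+        # length and restore the original table once it is back in range. With the
+        # "default" rope_type used by the Nexus-Gen config, forward() never reaches this
+        # method, so this path is effectively dormant.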
+ if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5_VLAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + from transformers.activations import ACT2FN + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + 
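+        # RMSNorm: x * rsqrt(mean(x^2) + eps), scaled by a learned weight; computed in
+        # float32 for stability and cast back to the input dtype (no mean subtraction,
+        # unlike LayerNorm).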
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2_5_VLAttention(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NexusGenImageEmbeddingMerger(nn.Module): + def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'): + super().__init__() + from transformers import Qwen2_5_VLConfig + from transformers.activations import ACT2FN + config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.config = config + self.num_layers = num_layers + self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)]) + self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps), + nn.Linear(config.hidden_size, out_channel * expand_ratio), + Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps), + ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel), + 
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps)) + self.base_grid = torch.tensor([[1, 72, 72]], device=device) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device) + + def get_position_ids(self, image_grid_thw): + """ + Generates position ids for the input embeddings grid. + modified from the qwen2_vl mrope. + """ + batch_size = image_grid_thw.shape[0] + spatial_merge_size = self.config.vision_config.spatial_merge_size + t, h, w = ( + image_grid_thw[0][0], + image_grid_thw[0][1], + image_grid_thw[0][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + scale_h = self.base_grid[0][1].item() / h.item() + scale_w = self.base_grid[0][2].item() / w.item() + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + time_tensor = expanded_range * self.config.vision_config.tokens_per_second + t_index = time_tensor.long().flatten().to(image_grid_thw.device) + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w + # 3, B, L + position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2) + return position_ids + + def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None): + position_ids = self.get_position_ids(embeds_grid) + hidden_states = embeds + if ref_embeds is not None: + position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid) + position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1) + hidden_states = torch.cat((embeds, ref_embeds), dim=1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings) + + hidden_states = self.projector(hidden_states) + return hidden_states + + @staticmethod + def state_dict_converter(): + return NexusGenMergerStateDictConverter() + + +class NexusGenMergerStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')} + return merger_state_dict + + +class NexusGenAdapter(nn.Module): + """ + Adapter for Nexus-Gen generation decoder. 
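+ Maps decoder hidden states from `input_dim` (3584 by default) to `output_dim` (4096 by default)
+ via two Linear layers with LayerNorm and a ReLU in between.
+
+ Illustrative usage (shapes follow the defaults above):
+ >>> adapter = NexusGenAdapter()
+ >>> adapter(torch.randn(1, 77, 3584)).shape
+ torch.Size([1, 77, 4096])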
+ """ + def __init__(self, input_dim=3584, output_dim=4096): + super(NexusGenAdapter, self).__init__() + self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim), + nn.LayerNorm(output_dim), nn.ReLU(), + nn.Linear(output_dim, output_dim), + nn.LayerNorm(output_dim)) + + def forward(self, x): + return self.adapter(x) + + @staticmethod + def state_dict_converter(): + return NexusGenAdapterStateDictConverter() + + +class NexusGenAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')} + return adapter_state_dict diff --git a/DiffSynth-Studio/diffsynth/models/qwen_image_controlnet.py b/DiffSynth-Studio/diffsynth/models/qwen_image_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce40809065b3eac020d2b1da29101681a44764a --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/qwen_image_controlnet.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +from .general_modules import RMSNorm + + +class BlockWiseControlBlock(torch.nn.Module): + # [linear, gelu, linear] + def __init__(self, dim: int = 3072): + super().__init__() + self.x_rms = RMSNorm(dim, eps=1e-6) + self.y_rms = RMSNorm(dim, eps=1e-6) + self.input_proj = nn.Linear(dim, dim) + self.act = nn.GELU() + self.output_proj = nn.Linear(dim, dim) + + def forward(self, x, y): + x, y = self.x_rms(x), self.y_rms(y) + x = self.input_proj(x + y) + x = self.act(x) + x = self.output_proj(x) + return x + + def init_weights(self): + # zero initialize output_proj + nn.init.zeros_(self.output_proj.weight) + nn.init.zeros_(self.output_proj.bias) + + +class QwenImageBlockWiseControlNet(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + in_dim: int = 64, + additional_in_dim: int = 0, + dim: int = 3072, + ): + super().__init__() + self.img_in = nn.Linear(in_dim + additional_in_dim, dim) + self.controlnet_blocks = nn.ModuleList( + [ + BlockWiseControlBlock(dim) + for _ in range(num_layers) + ] + ) + + def init_weight(self): + nn.init.zeros_(self.img_in.weight) + nn.init.zeros_(self.img_in.bias) + for block in self.controlnet_blocks: + block.init_weights() + + def process_controlnet_conditioning(self, controlnet_conditioning): + return self.img_in(controlnet_conditioning) + + def blockwise_forward(self, img, controlnet_conditioning, block_id): + return self.controlnet_blocks[block_id](img, controlnet_conditioning) diff --git a/DiffSynth-Studio/diffsynth/models/qwen_image_dit.py b/DiffSynth-Studio/diffsynth/models/qwen_image_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd51439c67803ce68500cdfe18e0795a4c4a929 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/qwen_image_dit.py @@ -0,0 +1,685 @@ +import torch, math, functools +import torch.nn as nn +from typing import Tuple, Optional, Union, List +from einops import rearrange +from .general_modules import TimestepEmbeddings, RMSNorm, AdaLayerNorm + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + + +def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False): + if FLASH_ATTN_3_AVAILABLE and attention_mask is None: + if not enable_fp8_attention: + q = rearrange(q, "b n s d -> b s n d", n=num_heads) + k = rearrange(k, "b n s d -> b s n d", 
n=num_heads) + v = rearrange(v, "b n s d -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x, tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + else: + origin_dtype = q.dtype + q_std, k_std, v_std = q.std(), k.std(), v.std() + q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn) + q = rearrange(q, "b n s d -> b s n d", n=num_heads) + k = rearrange(k, "b n s d -> b s n d", n=num_heads) + v = rearrange(v, "b n s d -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1))) + if isinstance(x, tuple): + x = x[0] + x = x.to(origin_dtype) * v_std + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +class ApproximateGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] +): + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + return x_out.type_as(x) + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat([ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], dim=1) + self.neg_freqs = torch.cat([ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], dim=1) + self.rope_cache = {} + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer( + index, + 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + + def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens): + if isinstance(video_fhw, list): + video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3)) + _, height, width = video_fhw + if self.scale_rope: + max_vid_index = max(height // 2, width // 2) + else: + max_vid_index = max(height, width) + required_len = max_vid_index + max(txt_seq_lens) + cur_max_len = self.pos_freqs.shape[0] + if required_len <= cur_max_len: + return + + new_max_len = math.ceil(required_len / 512) * 512 + pos_index = torch.arange(new_max_len) + neg_index = torch.arange(new_max_len).flip(0) * -1 - 1 + self.pos_freqs = torch.cat([ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], dim=1) + self.neg_freqs = torch.cat([ + self.rope_params(neg_index, 
self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], dim=1) + return + + + def forward(self, video_fhw, txt_seq_lens, device): + self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone().contiguous() + vid_freqs.append(self.rope_cache[rope_key]) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
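+ # Text tokens take positions max_vid_index .. max_vid_index + max_len - 1, i.e. right after the
+ # largest image position on each axis, so their rotary phases follow the image tokens; the
+ # per-image frequency tables are then concatenated along the sequence dimension.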
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + + def forward_sampling(self, video_fhw, txt_seq_lens, device): + self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache: + frame_0, height_0, width_0 = video_fhw[0] + + rope_key_0 = f"0_{height_0}_{width_0}" + spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1) + h_indices = torch.linspace(0, height_0 - 1, height).long() + w_indices = torch.linspace(0, width_0 - 1, width).long() + h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij') + sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :] + + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame + + seq_lens = frame * height * width + self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone() + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone() + vid_freqs.append(self.rope_cache[rope_key].contiguous()) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
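+ # Same layout as forward(): text positions start at max_vid_index, right after the image tokens.
+ # The difference is that entries with idx > 0 reuse the first image's cached spatial grid,
+ # resampled to the new height/width via torch.linspace, with only the frame-axis part rebuilt for idx.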
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + +class QwenEmbedLayer3DRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + video_fhw = [video_fhw] + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + layer_num = len(video_fhw) - 1 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + if idx != layer_num: + video_freq = self._compute_video_freqs(frame, height, width, idx) + else: + ### For the condition image, we set the layer index to -1 + video_freq = self._compute_condition_freqs(frame, height, width) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_vid_index = max(max_vid_index, layer_num) + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] 
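+ # Every layer except the last uses its own layer index on the frame axis; the last entry is the
+ # condition image and sits at frame index -1 (taken from neg_freqs in _compute_condition_freqs).
+ # Text positions start after max(spatial index, layer_num) so they clear both the spatial and layer axes.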
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + @functools.lru_cache(maxsize=None) + def _compute_condition_freqs(self, frame, height, width): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenFeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dropout: float = 0.0, + ): + super().__init__() + inner_dim = int(dim * 4) + self.net = nn.ModuleList([]) + self.net.append(ApproximateGELU(dim, inner_dim)) + self.net.append(nn.Dropout(dropout)) + self.net.append(nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + +class QwenDoubleStreamAttention(nn.Module): + def __init__( + self, + dim_a, + dim_b, + num_heads, + head_dim, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = nn.Linear(dim_a, dim_a) + self.to_k = nn.Linear(dim_a, dim_a) + self.to_v = nn.Linear(dim_a, dim_a) + self.norm_q = RMSNorm(head_dim, eps=1e-6) + self.norm_k = RMSNorm(head_dim, eps=1e-6) + + self.add_q_proj = nn.Linear(dim_b, dim_b) + self.add_k_proj = nn.Linear(dim_b, dim_b) + self.add_v_proj = nn.Linear(dim_b, dim_b) + self.norm_added_q = RMSNorm(head_dim, eps=1e-6) + self.norm_added_k = RMSNorm(head_dim, eps=1e-6) + + self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a)) + self.to_add_out = nn.Linear(dim_b, 
dim_b) + + def forward( + self, + image: torch.FloatTensor, + text: torch.FloatTensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + enable_fp8_attention: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image) + txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text) + seq_txt = txt_q.shape[1] + + img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads) + img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads) + img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads) + + txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads) + txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads) + txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads) + + img_q, img_k = self.norm_q(img_q), self.norm_k(img_k) + txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k) + + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_q = apply_rotary_emb_qwen(img_q, img_freqs) + img_k = apply_rotary_emb_qwen(img_k, img_freqs) + txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs) + txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs) + + joint_q = torch.cat([txt_q, img_q], dim=2) + joint_k = torch.cat([txt_k, img_k], dim=2) + joint_v = torch.cat([txt_v, img_v], dim=2) + + joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype) + + txt_attn_output = joint_attn_out[:, :seq_txt, :] + img_attn_output = joint_attn_out[:, seq_txt:, :] + + img_attn_output = self.to_out(img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim), + ) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = QwenDoubleStreamAttention( + dim_a=dim, + dim_b=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + ) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim) + + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), + ) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim) + + def _modulate(self, x, mod_params, index=None): + shift, scale, gate = mod_params.chunk(3, dim=-1) + if index is not None: + # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts) + # So shift, scale, gate have shape [2*actual_batch, d] + actual_batch = shift.size(0) // 2 + shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d] + scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:] + gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:] + + # index: [b, l] where b is actual batch size + # Expand to [b, l, 1] to match feature dimension + index_expanded = 
index.unsqueeze(-1) # [b, l, 1] + + # Expand chunks to [b, 1, d] then broadcast to [b, l, d] + shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d] + shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d] + scale_0_exp = scale_0.unsqueeze(1) + scale_1_exp = scale_1.unsqueeze(1) + gate_0_exp = gate_0.unsqueeze(1) + gate_1_exp = gate_1.unsqueeze(1) + + # Use torch.where to select based on index + shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp) + scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp) + gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp) + else: + shift_result = shift.unsqueeze(1) + scale_result = scale.unsqueeze(1) + gate_result = gate.unsqueeze(1) + + return x * (1 + scale_result) + shift_result, gate_result + + def forward( + self, + image: torch.Tensor, + text: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + enable_fp8_attention = False, + modulate_index: Optional[List[int]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + if modulate_index is not None: + temb = torch.chunk(temb, 2, dim=0)[0] + txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + + img_normed = self.img_norm1(image) + img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index) + + txt_normed = self.txt_norm1(text) + txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn) + + img_attn_out, txt_attn_out = self.attn( + image=img_modulated, + text=txt_modulated, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + enable_fp8_attention=enable_fp8_attention, + ) + + image = image + img_gate * img_attn_out + text = text + txt_gate * txt_attn_out + + img_normed_2 = self.img_norm2(image) + img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index) + + txt_normed_2 = self.txt_norm2(text) + txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp) + + img_mlp_out = self.img_mlp(img_modulated_2) + txt_mlp_out = self.txt_mlp(txt_modulated_2) + + image = image + img_gate_2 * img_mlp_out + text = text + txt_gate_2 * txt_mlp_out + + return text, image + + +class QwenImageDiT(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + use_layer3d_rope: bool = False, + use_additional_t_cond: bool = False, + ): + super().__init__() + + if not use_layer3d_rope: + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True) + else: + self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True) + + self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond) + self.txt_norm = RMSNorm(3584, eps=1e-6) + + self.img_in = nn.Linear(64, 3072) + self.txt_in = nn.Linear(3584, 3072) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=3072, + num_attention_heads=24, + attention_head_dim=128, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = AdaLayerNorm(3072, single=True) + self.proj_out = nn.Linear(3072, 64) + + + def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes): + # prompt_emb + all_prompt_emb = entity_prompt_emb + [prompt_emb] + 
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb] + all_prompt_emb = torch.cat(all_prompt_emb, dim=1) + + # image_rotary_emb + txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask] + entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens] + txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) + + # attention_mask + repeat_dim = latents.shape[1] + max_masks = entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype) + entity_masks = entity_masks + [global_mask] + + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()] + total_seq_len = sum(seq_lens) + image.shape[1] + patched_masks = [] + for i in range(N): + patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + patched_masks.append(patched_mask) + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + # prompt-image attention mask + image_start = sum(seq_lens) + image_end = total_seq_len + cumsum = [0] + single_image_seq = image_end - image_start + for length in seq_lens: + cumsum.append(cumsum[-1] + length) + for i in range(N): + prompt_start = cumsum[i] + prompt_end = cumsum[i+1] + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1) + # repeat image mask to match the single image sequence length + repeat_time = single_image_seq // image_mask.shape[-1] + image_mask = image_mask.repeat(1, 1, repeat_time) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt attention mask, let the prompt tokens not attend to each other + for i in range(N): + for j in range(N): + if i == j: + continue + start_i, end_i = cumsum[i], cumsum[i+1] + start_j, end_j = cumsum[j], cumsum[j+1] + attention_mask[:, start_i:end_i, start_j:end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) + + return all_prompt_emb, image_rotary_emb, attention_mask + + + def forward( + self, + latents=None, + timestep=None, + prompt_emb=None, + prompt_emb_mask=None, + height=None, + width=None, + ): + img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)] + txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() + + image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + image = self.img_in(image) + text = self.txt_in(self.txt_norm(prompt_emb)) + + conditioning = self.time_text_embed(timestep, image.dtype) + + image_rotary_emb = 
self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + + for block in self.transformer_blocks: + text, image = block( + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + ) + + image = self.norm_out(image, conditioning) + image = self.proj_out(image) + + latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2) + return image diff --git a/DiffSynth-Studio/diffsynth/models/qwen_image_image2lora.py b/DiffSynth-Studio/diffsynth/models/qwen_image_image2lora.py new file mode 100644 index 0000000000000000000000000000000000000000..6aefbf25de6ccdb37de2d2d44e644fb77952b570 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/qwen_image_image2lora.py @@ -0,0 +1,128 @@ +import torch + + +class CompressedMLP(torch.nn.Module): + def __init__(self, in_dim, mid_dim, out_dim, bias=False): + super().__init__() + self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias) + self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias) + + def forward(self, x, residual=None): + x = self.proj_in(x) + if residual is not None: x = x + residual + x = self.proj_out(x) + return x + + +class ImageEmbeddingToLoraMatrix(torch.nn.Module): + def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank): + super().__init__() + self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank) + self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank) + self.lora_a_dim = lora_a_dim + self.lora_b_dim = lora_b_dim + self.rank = rank + + def forward(self, x, residual=None): + lora_a = self.proj_a(x, residual).view(self.rank, self.lora_a_dim) + lora_b = self.proj_b(x, residual).view(self.lora_b_dim, self.rank) + return lora_a, lora_b + + +class SequencialMLP(torch.nn.Module): + def __init__(self, length, in_dim, mid_dim, out_dim, bias=False): + super().__init__() + self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias) + self.proj_out = torch.nn.Linear(length * mid_dim, out_dim, bias=bias) + self.length = length + self.in_dim = in_dim + self.mid_dim = mid_dim + + def forward(self, x): + x = x.view(self.length, self.in_dim) + x = self.proj_in(x) + x = x.view(1, self.length * self.mid_dim) + x = self.proj_out(x) + return x + + +class LoRATrainerBlock(torch.nn.Module): + def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = lora_patterns + self.block_id = block_id + self.layers = [] + for name, lora_a_dim, lora_b_dim in self.lora_patterns: + self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)) + self.layers = torch.nn.ModuleList(self.layers) + if use_residual: + self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim) + else: + self.proj_residual = None + + def forward(self, x, residual=None): + lora = {} + if self.proj_residual is not None: residual = self.proj_residual(residual) + for lora_pattern, layer in zip(self.lora_patterns, self.layers): + name = lora_pattern[0] + lora_a, lora_b = layer(x, residual=residual) + lora[f"transformer_blocks.{self.block_id}.{name}.lora_A.default.weight"] = lora_a + lora[f"transformer_blocks.{self.block_id}.{name}.lora_B.default.weight"] = lora_b + return lora + + +class QwenImageImage2LoRAModel(torch.nn.Module): + def __init__(self, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, 
residual_mid_dim=1024): + super().__init__() + self.lora_patterns = [ + [ + ("attn.to_q", 3072, 3072), + ("attn.to_k", 3072, 3072), + ("attn.to_v", 3072, 3072), + ("attn.to_out.0", 3072, 3072), + ], + [ + ("img_mlp.net.2", 3072*4, 3072), + ("img_mod.1", 3072, 3072*6), + ], + [ + ("attn.add_q_proj", 3072, 3072), + ("attn.add_k_proj", 3072, 3072), + ("attn.add_v_proj", 3072, 3072), + ("attn.to_add_out", 3072, 3072), + ], + [ + ("txt_mlp.net.2", 3072*4, 3072), + ("txt_mod.1", 3072, 3072*6), + ], + ] + self.num_blocks = num_blocks + self.blocks = [] + for lora_patterns in self.lora_patterns: + for block_id in range(self.num_blocks): + self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim)) + self.blocks = torch.nn.ModuleList(self.blocks) + self.residual_scale = 0.05 + self.use_residual = use_residual + + def forward(self, x, residual=None): + if residual is not None: + if self.use_residual: + residual = residual * self.residual_scale + else: + residual = None + lora = {} + for block in self.blocks: + lora.update(block(x, residual)) + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if ".proj_a." in name: + state_dict[name] = state_dict[name] * 0.3 + elif ".proj_b.proj_out." in name: + state_dict[name] = state_dict[name] * 0 + elif ".proj_residual.proj_out." in name: + state_dict[name] = state_dict[name] * 0.3 + self.load_state_dict(state_dict) diff --git a/DiffSynth-Studio/diffsynth/models/qwen_image_text_encoder.py b/DiffSynth-Studio/diffsynth/models/qwen_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f19d2d8ae2a61fd9cc45414a6bac18d28e4edcc9 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/qwen_image_text_encoder.py @@ -0,0 +1,190 @@ +import torch +from typing import Optional, Union + + +class QwenImageTextEncoder(torch.nn.Module): + def __init__(self): + super().__init__() + from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel + config = Qwen2_5_VLConfig(**{ + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "text_config": { + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": None, + "initializer_range": 0.02, + "intermediate_size": 18944, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", 
+ "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl_text", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": None, + "torch_dtype": "float32", + "use_cache": True, + "use_sliding_window": False, + "video_token_id": None, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }, + "tie_word_embeddings": False, + "torch_dtype": "float32", + "transformers_version": "4.54.0", + "use_cache": True, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "depth": 32, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "in_channels": 3, + "in_chans": 3, + "initializer_range": 0.02, + "intermediate_size": 3420, + "model_type": "qwen2_5_vl", + "num_heads": 16, + "out_hidden_size": 3584, + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + "tokens_per_second": 2, + "torch_dtype": "float32", + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.model = Qwen2_5_VLModel(config) + self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.config = config + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ): + output_attentions = False + output_hidden_states = True + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + return outputs.hidden_states diff --git a/DiffSynth-Studio/diffsynth/models/qwen_image_vae.py b/DiffSynth-Studio/diffsynth/models/qwen_image_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..2845354f24ad68214fe3f0ca1264ea9450ab16f3 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/qwen_image_vae.py @@ -0,0 +1,726 @@ +import torch +from typing import List, Optional, Tuple, Union +from 
torch import nn + + +CACHE_T = 2 + +class QwenImageCausalConv3d(torch.nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = torch.nn.functional.pad(x, padding) + return super().forward(x) + + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return torch.nn.functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". 
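+ Both convolutions are QwenImageCausalConv3d, so the block is causal along the time axis; the
+ optional feat_cache / feat_idx arguments of forward() keep the trailing CACHE_T frames of each
+ convolution input so that long sequences can be processed chunk by chunk.
+
+ Illustrative usage (single chunk, no cache):
+ >>> block = QwenImageResidualBlock(96, 192)
+ >>> block(torch.randn(1, 96, 1, 32, 32)).shape
+ torch.Size([1, 192, 1, 32, 32])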
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = torch.nn.SiLU() + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). 
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for WanVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
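+ num_layers (int): Number of (attention, residual) pairs appended after the initial residual block.
+
+ The forward pass runs one residual block and then alternates attention and residual blocks, all at
+ the same channel width `dim`; feat_cache / feat_idx are forwarded to the residual blocks for
+ chunked inference.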
+ """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + image_channels=3 + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = torch.nn.SiLU() + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = torch.nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if 
feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
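+        image_channels (int): Number of channels in the decoded output image (3 by default).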
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + image_channels=3, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = torch.nn.SiLU() + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + + +class QwenImageVAE(torch.nn.Module): + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + image_channels: int = 3, + ) -> None: + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels, + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = 
QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels, + ) + + mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + self.mean = torch.tensor(mean).view(1, 16, 1, 1, 1) + self.std = 1 / torch.tensor(std).view(1, 16, 1, 1, 1) + + def encode(self, x, **kwargs): + x = x.unsqueeze(2) + x = self.encoder(x) + x = self.quant_conv(x) + x = x[:, :16] + mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device) + x = (x - mean) * std + x = x.squeeze(2) + return x + + def decode(self, x, **kwargs): + x = x.unsqueeze(2) + mean, std = self.mean.to(dtype=x.dtype, device=x.device), self.std.to(dtype=x.dtype, device=x.device) + x = x / std + mean + x = self.post_quant_conv(x) + x = self.decoder(x) + x = x.squeeze(2) + return x diff --git a/DiffSynth-Studio/diffsynth/models/sd_text_encoder.py b/DiffSynth-Studio/diffsynth/models/sd_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b0a1171c265a43048c1623bd0c2375f4fc3f5e5d --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/sd_text_encoder.py @@ -0,0 +1,412 @@ +import torch +from .attention import Attention +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = 
qkv_preprocessor(q, k, v) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # 
final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": 
"encoders.0.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/models/siglip2_image_encoder.py b/DiffSynth-Studio/diffsynth/models/siglip2_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..509eff40946699dffcb31125916fc7acfad0caa0 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/siglip2_image_encoder.py @@ -0,0 +1,134 @@ +from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig +from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast +import torch + +from 
diffsynth.core.device.npu_compatible_device import get_device_type + + +class Siglip2ImageEncoder(SiglipVisionTransformer): + def __init__(self): + config = SiglipVisionConfig( + attention_dropout = 0.0, + dtype = "float32", + hidden_act = "gelu_pytorch_tanh", + hidden_size = 1536, + image_size = 384, + intermediate_size = 6144, + layer_norm_eps = 1e-06, + model_type = "siglip_vision_model", + num_attention_heads = 16, + num_channels = 3, + num_hidden_layers = 40, + patch_size = 16, + transformers_version = "4.56.1", + _attn_implementation = "sdpa" + ) + super().__init__(config) + self.processor = SiglipImageProcessor( + do_convert_rgb = None, + do_normalize = True, + do_rescale = True, + do_resize = True, + image_mean = [ + 0.5, + 0.5, + 0.5 + ], + image_processor_type = "SiglipImageProcessor", + image_std = [ + 0.5, + 0.5, + 0.5 + ], + processor_class = "SiglipProcessor", + resample = 2, + rescale_factor = 0.00392156862745098, + size = { + "height": 384, + "width": 384 + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()): + pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"] + pixel_values = pixel_values.to(device=device, dtype=torch_dtype) + output_attentions = False + output_hidden_states = False + interpolate_pos_encoding = False + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return pooler_output + + +class Siglip2ImageEncoder428M(Siglip2VisionModel): + def __init__(self): + config = Siglip2VisionConfig( + attention_dropout = 0.0, + dtype = "bfloat16", + hidden_act = "gelu_pytorch_tanh", + hidden_size = 1152, + intermediate_size = 4304, + layer_norm_eps = 1e-06, + model_type = "siglip2_vision_model", + num_attention_heads = 16, + num_channels = 3, + num_hidden_layers = 27, + num_patches = 256, + patch_size = 16, + transformers_version = "4.57.1" + ) + super().__init__(config) + self.processor = Siglip2ImageProcessorFast( + **{ + "data_format": "channels_first", + "default_to_square": True, + "device": None, + "disable_grouping": None, + "do_convert_rgb": None, + "do_normalize": True, + "do_pad": None, + "do_rescale": True, + "do_resize": True, + "image_mean": [ + 0.5, + 0.5, + 0.5 + ], + "image_processor_type": "Siglip2ImageProcessorFast", + "image_std": [ + 0.5, + 0.5, + 0.5 + ], + "input_data_format": None, + "max_num_patches": 256, + "pad_size": None, + "patch_size": 16, + "processor_class": "Siglip2Processor", + "resample": 2, + "rescale_factor": 0.00392156862745098, + "return_tensors": None, + } + ) + + def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = super().forward(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, : shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + hidden_state = hidden_state.to(torch_dtype) + return hidden_state diff --git a/DiffSynth-Studio/diffsynth/models/step1x_connector.py b/DiffSynth-Studio/diffsynth/models/step1x_connector.py new file mode 100644 
index 0000000000000000000000000000000000000000..225c8fbcb54f8daebf48656141e9a8998d002fd8 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/step1x_connector.py @@ -0,0 +1,663 @@ +from typing import Optional + +import torch, math +import torch.nn +from einops import rearrange +from torch import nn +from functools import partial +from einops import rearrange + + + +def attention(q, k, v, attn_mask, mode="torch"): + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = rearrange(x, "b n s d -> b s (n d)") + return x + + + +class MLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = (bias, bias) + drop_probs = (drop, drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer( + in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype + ) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_channels, device=device, dtype=dtype) + if norm_layer is not None + else nn.Identity() + ) + self.fc2 = linear_layer( + hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype + ) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class TextProjection(nn.Module): + """ + Projects text embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.linear_1 = nn.Linear( + in_features=in_channels, + out_features=hidden_size, + bias=True, + **factory_kwargs, + ) + self.act_1 = act_layer() + self.linear_2 = nn.Linear( + in_features=hidden_size, + out_features=hidden_size, + bias=True, + **factory_kwargs, + ) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__( + self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, hidden_size, bias=True, **factory_kwargs + ), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore + nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. 
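+        Concretely, freqs[i] = exp(-log(max_period) * i / (dim // 2)) for i in [0, dim // 2), and the
+        returned embedding concatenates cos(t * freqs) and sin(t * freqs), zero-padded by one column
+        when dim is odd.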
+ + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding( + t, self.frequency_embedding_size, self.max_period + ).type(t.dtype) # type: ignore + t_emb = self.mlp(t_freq) + return t_emb + + +def apply_gate(x, gate=None, tanh=False): + """AI is creating summary for apply_gate + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after apply gate. + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. 
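+
+    Note:
+        Supported values are "layer" (nn.LayerNorm) and "rms" (RMSNorm); any other value raises
+        NotImplementedError.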
+ """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +def get_activation_layer(act_type): + """get activation layer + + Args: + act_type (str): the activation type + + Returns: + torch.nn.functional: the activation layer + """ + if act_type == "gelu": + return lambda: nn.GELU() + elif act_type == "gelu_tanh": + return lambda: nn.GELU(approximate="tanh") + elif act_type == "relu": + return nn.ReLU + elif act_type == "silu": + return nn.SiLU + else: + raise ValueError(f"Unknown activation type: {act_type}") + +class IndividualTokenRefinerBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.need_CA = need_CA + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.self_attn_qkv = nn.Linear( + hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + act_layer = get_activation_layer(act_type) + self.mlp = MLP( + in_channels=hidden_size, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop_rate, + **factory_kwargs, + ) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + + if self.need_CA: + self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs,) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor = None, + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + if self.need_CA: + x = self.cross_attnblock(x, c, attn_mask, y) + + # FFN Layer + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + + return x + + + + +class 
CrossAttnBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.norm1_2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.self_attn_q = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + self.self_attn_kv = nn.Linear( + hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + act_layer = get_activation_layer(act_type) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor=None, + + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + norm_y = self.norm1_2(y) + q = self.self_attn_q(norm_x) + q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num) + kv = self.self_attn_kv(norm_y) + k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + return x + + + +class IndividualTokenRefiner(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA:bool=False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.need_CA = need_CA + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + need_CA=self.need_CA, + **factory_kwargs, + ) + for _ in range(depth) + ] + ) + + + def forward( + self, + x: torch.Tensor, + c: torch.LongTensor, + mask: Optional[torch.Tensor] = None, + y:torch.Tensor=None, + ): + self_attn_mask = None + if mask is not None: + batch_size = mask.shape[0] 
+ seq_len = mask.shape[1] + mask = mask.to(x.device) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( + 1, 1, seq_len, 1 + ) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + # avoids self-attention weight being NaN for padding tokens + self_attn_mask[:, :, :, 0] = True + + + for block in self.blocks: + x = block(x, c, self_attn_mask,y) + + return x + + +class SingleTokenRefiner(torch.nn.Module): + """ + A single token refiner block for llm text embedding refine. + """ + def __init__( + self, + in_channels, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + need_CA:bool=False, + attn_mode: str = "torch", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_mode = attn_mode + self.need_CA = need_CA + assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." + + self.input_embedder = nn.Linear( + in_channels, hidden_size, bias=True, **factory_kwargs + ) + if self.need_CA: + self.input_embedder_CA = nn.Linear( + in_channels, hidden_size, bias=True, **factory_kwargs + ) + + act_layer = get_activation_layer(act_type) + # Build timestep embedding layer + self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) + # Build context embedding layer + self.c_embedder = TextProjection( + in_channels, hidden_size, act_layer, **factory_kwargs + ) + + self.individual_token_refiner = IndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + need_CA=need_CA, + **factory_kwargs, + ) + + def forward( + self, + x: torch.Tensor, + t: torch.LongTensor, + mask: Optional[torch.LongTensor] = None, + y: torch.LongTensor=None, + ): + timestep_aware_representations = self.t_embedder(t) + + if mask is None: + context_aware_representations = x.mean(dim=1) + else: + mask_float = mask.unsqueeze(-1) # [b, s1, 1] + context_aware_representations = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + if self.need_CA: + y = self.input_embedder_CA(y) + x = self.individual_token_refiner(x, c, mask, y) + else: + x = self.individual_token_refiner(x, c, mask) + + return x + + +class Qwen2Connector(torch.nn.Module): + def __init__( + self, + # biclip_dim=1024, + in_channels=3584, + hidden_size=4096, + heads_num=32, + depth=2, + need_CA=False, + device=None, + dtype=torch.bfloat16, + ): + super().__init__() + factory_kwargs = {"device": device, "dtype":dtype} + + self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs) + self.global_proj_out=nn.Linear(in_channels,768) + + self.scale_factor = nn.Parameter(torch.zeros(1)) + with torch.no_grad(): + self.scale_factor.data += -(1 - 0.09) + + def forward(self, x,t,mask): + mask_float = 
mask.unsqueeze(-1) # [b, s1, 1] + x_mean = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device)) + + global_out=self.global_proj_out(x_mean) + encoder_hidden_states = self.S(x,t,mask) + return encoder_hidden_states,global_out diff --git a/DiffSynth-Studio/diffsynth/models/step1x_text_encoder.py b/DiffSynth-Studio/diffsynth/models/step1x_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5d144236a33c923b8650057b420681ea48006efd --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/step1x_text_encoder.py @@ -0,0 +1,195 @@ +import torch +from typing import Optional, Union +from .qwen_image_text_encoder import QwenImageTextEncoder +from ..core.device.npu_compatible_device import get_device_type, get_torch_device + + +class Step1xEditEmbedder(torch.nn.Module): + def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()): + super().__init__() + self.max_length = max_length + self.dtype = dtype + self.device = device + + Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: +- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. +- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n +Here are examples of how to transform or refine prompts: +- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. 
+- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n +Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: +User Prompt:''' + + self.prefix = Qwen25VL_7b_PREFIX + self.model = model + self.processor = processor + + def model_forward( + self, + model: QwenImageTextEncoder, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states + ) + + outputs = model.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + return outputs.hidden_states + + def forward(self, caption, ref_images): + text_list = caption + embs = torch.zeros( + len(text_list), + self.max_length, + self.model.config.hidden_size, + dtype=torch.bfloat16, + device=get_torch_device().current_device(), + ) + masks = torch.zeros( + len(text_list), + self.max_length, + dtype=torch.long, + device=get_torch_device().current_device(), + ) + + def split_string(s): + s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes + result = [] + in_quotes = False + temp = "" + + for idx,char in enumerate(s): + if char == '"' and idx>155: + temp += char + if not in_quotes: + result.append(temp) + temp = "" + + in_quotes = not in_quotes + continue + if in_quotes: + if char.isspace(): + pass # have space token + + result.append("“" + char + "”") + else: + temp += char + + if temp: + result.append(temp) + + return result + + for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)): + + messages = [{"role": "user", "content": []}] + + messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"}) + + messages[0]["content"].append({"type": "image", "image": imgs}) + + # 再添加 text + messages[0]["content"].append({"type": "text", "text": f"{txt}"}) + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + + image_inputs = 
[imgs] + + inputs = self.processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ) + + old_inputs_ids = inputs.input_ids + text_split_list = split_string(text) + + token_list = [] + for text_each in text_split_list: + txt_inputs = self.processor( + text=text_each, + images=None, + videos=None, + padding=True, + return_tensors="pt", + ) + token_each = txt_inputs.input_ids + if token_each[0][0] == 2073 and token_each[0][-1] == 854: + token_each = token_each[:, 1:-1] + token_list.append(token_each) + else: + token_list.append(token_each) + + new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type()) + + new_txt_ids = new_txt_ids.to(old_inputs_ids.device) + + idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0] + idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] + inputs.input_ids = ( + torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0) + .unsqueeze(0) + .to(get_device_type()) + ) + inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type()) + outputs = self.model_forward( + self.model, + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + pixel_values=inputs.pixel_values.to(get_device_type()), + image_grid_thw=inputs.image_grid_thw.to(get_device_type()), + output_hidden_states=True, + ) + + emb = outputs[-1] + + embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][ + : self.max_length + ] + + masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones( + (min(self.max_length, emb.shape[1] - 217)), + dtype=torch.long, + device=get_torch_device().current_device(), + ) + + return embs, masks diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_animate_adapter.py b/DiffSynth-Studio/diffsynth/models/wan_video_animate_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..3ace70d8b3162e77844f0957cd40207a54e674a9 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_animate_adapter.py @@ -0,0 +1,650 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +from typing import Tuple, Optional, List +from einops import rearrange + + + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, 
dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. 
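+ Supported values are "layer" (nn.LayerNorm) and "rms" (RMSNorm); any other value raises NotImplementedError.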
+ """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) H N D") + v = rearrange(v, "B L N H D -> (B L) H N D") + + q = rearrange(q, "B (L S) H D -> (B L) H S D", L=T_comp) + # Compute attention. 
+ attn = F.scaled_dot_product_attention(q, k, v) + + attn = rearrange(attn, "(B L) H S D -> B (L S) (H D)", L=T_comp) + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output + + + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype == torch.bfloat16: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.kernel = torch.nn.Parameter(kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, 
stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, 
self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + motion = self.dec.direction(motion_feat) + return motion + + +class WanAnimateAdapter(torch.nn.Module): + def __init__(self): + super().__init__() + self.pose_patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter(heads_num=40, hidden_dim=5120, num_adapter_layers=40 // 5) + self.face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4) + + def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec + + def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None): + if block_idx % 5 == 0: + adapter_args = [x, motion_vec, motion_masks, False] + residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args) + x = residual_out + x + return x diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_camera_controller.py b/DiffSynth-Studio/diffsynth/models/wan_video_camera_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..45a44ee6bcd408d7ee9d18653f933151ce351a72 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_camera_controller.py @@ -0,0 +1,206 @@ +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +import os +from typing_extensions import Literal + +class SimpleAdapter(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1): + super(SimpleAdapter, self).__init__() + + # Pixel Unshuffle: reduce spatial dimensions by a factor of 8 + self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8) + + # Convolution: reduce spatial dimensions by a factor + # of 2 (without overlap) + self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0) + + # Residual blocks for feature extraction + self.residual_blocks = nn.Sequential( + 
*[ResidualBlock(out_dim) for _ in range(num_residual_blocks)] + ) + + def forward(self, x): + # Reshape to merge the frame dimension into batch + bs, c, f, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w) + + # Pixel Unshuffle operation + x_unshuffled = self.pixel_unshuffle(x) + + # Convolution operation + x_conv = self.conv(x_unshuffled) + + # Feature extraction with residual blocks + out = self.residual_blocks(x_conv) + + # Reshape to restore original bf dimension + out = out.view(bs, f, out.size(1), out.size(2), out.size(3)) + + # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames + out = out.permute(0, 2, 1, 3, 4) + + return out + + def process_camera_coordinates( + self, + direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"], + length: int, + height: int, + width: int, + speed: float = 1/54, + origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) + ): + if origin is None: + origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) + coordinates = generate_camera_coordinates(direction, length, speed, origin) + plucker_embedding = process_pose_file(coordinates, width, height) + return plucker_embedding + + + +class ResidualBlock(nn.Module): + def __init__(self, dim): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) + + def forward(self, x): + residual = x + out = self.relu(self.conv1(x)) + out = self.conv2(out) + out += residual + return out + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + w2c_mat = np.array(entry[7:]).reshape(3, 4) + w2c_mat_4x4 = np.eye(4) + w2c_mat_4x4[:3, :] = w2c_mat + self.w2c_mat = w2c_mat_4x4 + self.c2w_mat = np.linalg.inv(w2c_mat_4x4) + +def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + +def custom_meshgrid(*args): + # torch>=2.0.0 only + return torch.meshgrid(*args, indexing='ij') + + +def ray_condition(K, c2w, H, W, device): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = custom_meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + + fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 + + zs = torch.ones_like(i) # [B, HxW] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + zs = zs.expand_as(ys) + + directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 + directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 
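+ # Rotate the normalized camera-space ray directions into world space, take the camera origins as ray origins, and assemble the 6D Plücker embedding [o x d, d] for every pixel.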
+ + rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW + rays_o = c2w[..., :3, 3] # B, V, 3 + rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW + # c2w @ dirctions + rays_dxo = torch.linalg.cross(rays_o, rays_d) + plucker = torch.cat([rays_dxo, rays_d], dim=-1) + plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 + # plucker = plucker.permute(0, 1, 4, 2, 3) + return plucker + + +def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False): + if return_poses: + return cam_params + else: + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding + + + +def generate_camera_coordinates( + direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown", "In", "Out"], + length: int, + speed: float = 1/54, + origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) +): + coordinates = [list(origin)] + while len(coordinates) < length: + coor = coordinates[-1].copy() + if "Left" in direction: + coor[9] += speed + if "Right" in direction: + coor[9] -= speed + if "Up" in direction: + coor[13] += speed + if "Down" in direction: + coor[13] -= speed + if "In" in direction: + coor[18] -= speed + if "Out" in direction: + coor[18] += speed + coordinates.append(coor) + return coordinates diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_dit.py b/DiffSynth-Studio/diffsynth/models/wan_video_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9579f53e7f9a999e3e420938cad923c6ec205c --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_dit.py @@ -0,0 +1,467 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Tuple, Optional +from einops import rearrange +from .wan_video_camera_controller import SimpleAdapter + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + + +def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False): + if compatibility_mode: + q = rearrange(q, "b s 
(n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x,tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_2_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif SAGE_ATTN_AVAILABLE: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = sageattn(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + else: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return (x * (1 + scale) + shift) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + + +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + # 3d rope precompute + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return f_freqs_cis, h_freqs_cis, w_freqs_cis + + +def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def forward(self, x): + dtype = x.dtype + return self.norm(x.float()).to(dtype) * self.weight + + +class AttentionModule(nn.Module): + def __init__(self, num_heads): + super().__init__() + self.num_heads = num_heads + + def forward(self, q, k, v): + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) + return x + + +class FirstFrameAttention(nn.Module): + def __init__(self, 
dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.q = nn.Linear(dim, dim) # W'_Q + self.norm_q = RMSNorm(dim, eps=eps) + self.o = nn.Linear(dim, dim) # W'_O + + self.attn = AttentionModule(self.num_heads) + + nn.init.zeros_(self.o.weight) + nn.init.zeros_(self.o.bias) + + def forward(self, x, k1, v1): + + q = self.norm_q(self.q(x)) + + out = self.attn(q, k1, v1) + + return self.o(out) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, alpha=1, enable_i2v=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + self.enable_i2v = enable_i2v + self.alpha = alpha + if enable_i2v: + self.i2v_adapter = FirstFrameAttention(dim, num_heads, eps) + + def init_i2v_from_self(self): + + with torch.no_grad(): + self.i2v_adapter.q.weight.copy_(self.q.weight) + self.i2v_adapter.q.bias.copy_(self.q.bias) + self.i2v_adapter.norm_q.weight.copy_(self.norm_q.weight) + + + def forward(self, x, freqs, first_frame_len=None): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + attn_out = self.attn(q, k, v) + out = self.o(attn_out) + + if self.enable_i2v: + + x1 = x[:, :first_frame_len] + + k1 = self.norm_k(self.k(x1)) + v1 = self.v(x1) + + extra = self.i2v_adapter(x, k1, v1) + + start = first_frame_len + end = out.shape[1] // 2 + + out[:, start:end] += self.alpha * extra[:, start:end] + + return out + + +class CrossAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + self.has_image_input = has_image_input + if has_image_input: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + x = self.attn(q, k, v) + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) + x = x + y + return self.o(x) + + +class GateModule(nn.Module): + def __init__(self,): + super().__init__() + + def forward(self, x, gate, residual): + return x + gate * residual + +class DiTBlock(nn.Module): + def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention( + dim, num_heads, eps, has_image_input=has_image_input) + self.norm1 = nn.LayerNorm(dim, eps=eps, 
elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( + approximate='tanh'), nn.Linear(ffn_dim, dim)) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.gate = GateModule() + + def forward(self, x, context, t_mod, freqs, first_frame_len=None): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 + # msa: multi-head self-attention mlp: multi-layer perceptron + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) + + + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs, first_frame_len)) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x + + +class MLP(torch.nn.Module): + def __init__(self, in_dim, out_dim, has_pos_emb=False): + super().__init__() + self.proj = torch.nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.LayerNorm(out_dim) + ) + self.has_pos_emb = has_pos_emb + if has_pos_emb: + self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) + + def forward(self, x): + if self.has_pos_emb: + x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) + return self.proj(x) + + +class Head(nn.Module): + def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): + super().__init__() + self.dim = dim + self.patch_size = patch_size + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, t_mod): + if len(t_mod.shape) == 3: + shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) + x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) + return x + + +class WanModel(torch.nn.Module): + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + has_image_input: bool, + has_image_pos_emb: bool = False, + has_ref_conv: bool = False, + add_control_adapter: bool = False, + in_dim_control_adapter: int = 24, + seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.has_image_input = has_image_input + self.patch_size = patch_size + self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + + 
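+ # Patch embedding: Conv3d with kernel_size == stride == patch_size splits the latent video into non-overlapping (t, h, w) patches and projects each one to a `dim`-dimensional token.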
self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + self.blocks = nn.ModuleList([ + DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps) + for _ in range(num_layers) + ]) + self.head = Head(dim, out_dim, patch_size, eps) + head_dim = dim // num_heads + self.freqs = precompute_freqs_cis_3d(head_dim) + + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + if has_ref_conv: + self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) + self.has_image_pos_emb = has_image_pos_emb + self.has_ref_conv = has_ref_conv + if add_control_adapter: + self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) + else: + self.control_adapter = None + + def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None): + + + x = self.patch_embedding(x) + + if self.control_adapter is not None and control_camera_latents_input is not None: + y_camera = self.control_adapter(control_camera_latents_input) + x = [u + v for u, v in zip(x, y_camera)] + x = x[0].unsqueeze(0) + return x + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] + ) + + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block in self.blocks: + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_dit_s2v.py b/DiffSynth-Studio/diffsynth/models/wan_video_dit_s2v.py new file mode 100644 index 
0000000000000000000000000000000000000000..8fbed8c05a2b5acca4c1ecaa2a68100433c74366 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_dit_s2v.py @@ -0,0 +1,594 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple +from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d + + +def torch_dfs(model: nn.Module, parent_name='root'): + module_names, modules = [], [] + current_name = parent_name if parent_name else 'root' + module_names.append(current_name) + modules.append(model) + + for name, child in model.named_children(): + if parent_name: + child_name = f'{parent_name}.{name}' + else: + child_name = name + child_modules, child_names = torch_dfs(child, child_name) + module_names += child_names + modules += child_modules + return modules, module_names + + +def rope_precompute(x, grid_sizes, freqs, start=None): + b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2 + + # split freqs + if type(freqs) is list: + trainable_freqs = freqs[1] + freqs = freqs[0] + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64)) + seq_bucket = [0] + if not type(grid_sizes) is list: + grid_sizes = [grid_sizes] + for g in grid_sizes: + if not type(g) is list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + if start is None: + f_o, h_o, w_o = g[0][i] + else: + f_o, h_o, w_o = start[i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item() + # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item()) + if f_o >= 0: + f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist() + else: + f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist() + h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist() + w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist() + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat( + [ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1 + ).reshape(seq_len, 1, -1) + elif t_f < 0: + freqs_i = trainable_freqs.unsqueeze(1) + # apply rotary embedding + output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i + seq_bucket.append(seq_bucket[-1] + seq_len) + return output + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__(self, in_dim: int, 
hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1) + if need_global: + self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) + + if need_global: + self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs) + + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = rearrange(x, 'b t c -> b c t') + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, 'b c t -> b t c') + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + + return x, x_local + + +class FramePackMotioner(nn.Module): + + def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs): + super().__init__(*args, **kwargs) + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) + + self.inner_dim = inner_dim + self.num_heads = num_heads + self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1) + self.drop_mode = drop_mode + + def forward(self, motion_latents, add_last_motion=2): + motion_frames = motion_latents[0].shape[1] + mot = [] + mot_remb = [] + for m in motion_latents: + lat_height, lat_width = m.shape[2], m.shape[3] + padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype) + overlap_frame = min(padd_lat.shape[1], m.shape[1]) + if overlap_frame > 0: + padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = 
self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum() + padd_lat[:, -zero_end_frame:] = 0 + + padd_lat = padd_lat.unsqueeze(0) + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split( + list(self.zip_frame_buckets)[::-1], dim=2 + ) # 16, 2 ,1 + + # patchfy + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + # rope + start_time_id = -(self.zip_frame_buckets[:1].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[0] + grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:2].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 + grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:3].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 + grid_sizes_4x = [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + ] + ] + + grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x + + motion_rope_emb = rope_precompute( + motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads), + grid_sizes, + self.freqs, + start=None + ) + + mot.append(motion_lat) + mot_remb.append(motion_rope_emb) + return mot, mot_remb + + +class AdaLayerNorm(nn.Module): + + def __init__( + self, + embedding_dim: int, + output_dim: int, + norm_eps: float = 1e-5, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False) + + def forward(self, x, temb): + temb = self.linear(F.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + x = self.norm(x) * (1 + scale) + shift + return x + + +class AudioInjector_WAN(nn.Module): + + def __init__( + self, + all_modules, + all_modules_names, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + enable_adain=False, + adain_dim=2048, + ): + super().__init__() + self.injected_block_id = {} + audio_injector_id = 0 + for mod_name, mod in zip(all_modules_names, all_modules): + if isinstance(mod, DiTBlock): + for inject_id in inject_layer: + if 
f'transformer_blocks.{inject_id}' in mod_name: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList([CrossAttention( + dim=dim, + num_heads=num_heads, + ) for _ in range(audio_injector_id)]) + self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id)]) + self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id)]) + if enable_adain: + self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)]) + + +class CausalAudioEncoder(nn.Module): + + def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False): + super().__init__() + self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global) + weight = torch.ones((1, num_layers, 1, 1)) * 0.01 + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + # features B * num_layers * dim * video_length + weights = self.act(self.weights.to(device=features.device, dtype=features.dtype)) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + return res # b f n dim + + +class WanS2VDiTBlock(DiTBlock): + + def forward(self, x, context, t_mod, seq_len_x, freqs): + t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + # t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc. 
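+ # The [:, :, 0] slice is broadcast over the first seq_len_x video tokens and the [:, :, 1] slice over the remaining reference/motion tokens; the two are then concatenated along the sequence axis so each token group receives its own modulation parameters.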
+ t_mod = [ + torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1) + for element in t_mod + ] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + return x + + +class WanS2VModel(torch.nn.Module): + + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + cond_dim: int, + audio_dim: int, + num_audio_token: int, + enable_adain: bool = True, + audio_inject_layers: list = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + zero_timestep: bool = True, + add_last_motion: bool = True, + framepack_drop_mode: str = "padd", + fuse_vae_embedding_in_latents: bool = True, + require_vae_embedding: bool = False, + seperated_timestep: bool = False, + require_clip_embedding: bool = False, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.patch_size = patch_size + self.num_heads = num_heads + self.enbale_adain = enable_adain + self.add_last_motion = add_last_motion + self.zero_timestep = zero_timestep + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + self.require_vae_embedding = require_vae_embedding + self.seperated_timestep = seperated_timestep + self.require_clip_embedding = require_clip_embedding + + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim)) + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]) + self.head = Head(dim, out_dim, patch_size, eps) + self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1) + + self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size) + self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain) + all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks") + self.audio_injector = AudioInjector_WAN( + all_modules, + all_modules_names, + dim=dim, + num_heads=num_heads, + inject_layer=audio_inject_layers, + enable_adain=enable_adain, + adain_dim=dim, + ) + self.trainable_cond_mask = nn.Embedding(3, dim) + self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode) + + def patchify(self, x: torch.Tensor): + grid_size = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + return x, grid_size # x, grid_size: (f, h, w) + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, + 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], + h=grid_size[1], + w=grid_size[2], + x=self.patch_size[0], + y=self.patch_size[1], + z=self.patch_size[2] + ) + + def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, 
add_last_motion=2): + flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) + if drop_motion_frames: + return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb] + else: + return flattern_mot, mot_remb + + def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2): + # inject the motion frames token to the hidden states + mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion) + if len(mot) > 0: + x = torch.cat([x, mot[0]], dim=1) + rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1) + mask_input = torch.cat( + [mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1 + ) + return x, rope_embs, mask_input + + def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False): + if block_idx in self.audio_injector.injected_block_id.keys(): + audio_attn_id = self.audio_injector.injected_block_id[block_idx] + num_frames = audio_emb.shape[1] + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sp_group + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) + + input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c + input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames) + + audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c") + adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0]) + attn_hidden_states = adain_hidden_states + + audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb) + residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames) + hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank + hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + return hidden_states + + def cal_audio_emb(self, audio_input, motion_frames=[73, 19]): + audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1) + audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input) + audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone() + merged_audio_emb = audio_emb[:, motion_frames[1]:, :] + return audio_emb_global, merged_audio_emb + + def get_grid_sizes(self, grid_size_x, grid_size_ref): + f, h, w = grid_size_x + rf, rh, rw = grid_size_ref + grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0) + grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]] + grid_sizes_ref = [[ + torch.tensor([30, 0, 0]).unsqueeze(0), + torch.tensor([31, rh, rw]).unsqueeze(0), + torch.tensor([1, rh, rw]).unsqueeze(0), + ]] + return grid_sizes_x + grid_sizes_ref + + def forward( + self, + latents, + timestep, + context, + audio_input, + motion_latents, + pose_cond, + use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False + ): + origin_ref_latents = latents[:, :, 0:1] + x = latents[:, :, 1:] + + # context embedding + context = 
self.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input) + + # x and pose_cond + pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond + x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) + seq_len_x = x.shape[1] + + # reference image + ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) + grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw)) + x = torch.cat([x, ref_latents], dim=1) + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute( + x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None + ) + # motion + x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2) + + x = x + self.trainable_cond_mask(mask).to(x.dtype) + + # t_mod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(self.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + context, + t_mod, + seq_len_x, + pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + context, + t_mod, + seq_len_x, + pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) + + x = x[:, :seq_len_x] + x = self.head(x, t[:-1]) + x = self.unpatchify(x, (f, h, w)) + # make compatible with wan video + x = torch.cat([origin_ref_latents, x], dim=2) + return x diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_image_encoder.py b/DiffSynth-Studio/diffsynth/models/wan_video_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..37d17d6a183f5d6290c3f4b1e417bf0310bfa353 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_image_encoder.py @@ -0,0 +1,878 @@ +""" +Concise re-implementation of +``https://github.com/openai/CLIP'' and +``https://github.com/mlfoundations/open_clip''. 
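+It also defines the WanImageEncoder wrapper that extracts CLIP ViT features from video frames.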
+""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T +from .wan_video_dit import flash_attention + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), + nn.Dropout(dropout)) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. + """ + + def __init__(self, + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList([ + AttentionBlock(dim, num_heads, post_norm, dropout, eps) + for _ in range(num_layers) + ]) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. 
+ """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = self.token_embedding(ids) + \ + self.type_embedding(torch.zeros_like(ids)) + \ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where( + mask.view(b, 1, 1, s).gt(0), 0.0, + torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained=False, + return_tokenizer=False, + device='cpu', + **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. + """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5) + cfg.update(**kwargs) + + # init model + if pretrained: + from sora import DOWNLOAD_TO_CACHE + + # init a meta model + with torch.device('meta'): + model = XLMRoberta(**cfg) + + # load checkpoint + model.load_state_dict( + torch.load( + DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'), + map_location=device), + assign=True) + else: + # init a model on device + with torch.device(device): + model = XLMRoberta(**cfg) + + # init tokenizer + if return_tokenizer: + from sora.data import HuggingfaceTokenizer + tokenizer = HuggingfaceTokenizer( + name='xlm-roberta-large', + seq_len=model.text_len, + clean='whitespace') + return model, tokenizer + else: + return model + + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat([ + pos[:, :n], + F.interpolate( + pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + 0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode='bicubic', + align_corners=False).flatten(2).transpose(1, 2) + ], + dim=1) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x): + return super().forward(x).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + causal=False, + attn_dropout=0.0, + proj_dropout=0.0): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + """ + x: [B, L, C]. 
+ """ + # compute query, key, value + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + # compute attention + x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5): + assert activation in ['quick_gelu', 'gelu', 'swi_glu'] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, + proj_dropout) + self.norm2 = LayerNorm(dim, eps=norm_eps) + if activation == 'swi_glu': + self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) + else: + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + activation='gelu', + proj_dropout=0.0, + norm_eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim) + self.to_kv = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + self.norm = LayerNorm(dim, eps=norm_eps) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + """ + x: [B, L, C]. 
+ """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1) + k, v = self.to_kv(x).chunk(2, dim=-1) + + # compute attention + x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type='token', + pre_norm=True, + post_norm=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + if image_size % patch_size != 0: + print( + '[WARNING] image_size is not divisible by patch_size', + flush=True) + assert pool_type in ('token', 'token_fc', 'attn_pool') + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size)**2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, + dim, + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm) + if pool_type in ('token', 'token_fc'): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter(gain * torch.randn( + 1, self.num_patches + + (1 if pool_type in ('token', 'token_fc') else 0), dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, + activation, attn_dropout, proj_dropout, norm_eps) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim, eps=norm_eps) + + # head + if pool_type == 'token': + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + elif pool_type == 'token_fc': + self.head = nn.Linear(dim, out_dim) + elif pool_type == 'attn_pool': + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, + proj_dropout, norm_eps) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ('token', 'token_fc'): + x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + e = e.to(dtype=x.dtype, device=x.device) + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_mlp_ratio=4, + vision_heads=12, + vision_layers=12, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + vocab_size=49408, + text_len=77, + text_dim=512, + text_mlp_ratio=4, + text_heads=8, + text_layers=12, + text_causal=True, + text_pool='argmax', + text_head_bias=False, + logit_bias=None, + 
activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pool = vision_pool + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_mlp_ratio = text_mlp_ratio + self.text_heads = text_heads + self.text_layers = text_layers + self.text_causal = text_causal + self.text_pool = text_pool + self.text_head_bias = text_head_bias + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + mlp_ratio=text_mlp_ratio, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + causal=text_causal, + pool_type=text_pool, + head_bias=text_head_bias, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + if logit_bias is not None: + self.logit_bias = nn.Parameter(logit_bias * torch.ones([])) + + # initialize weights + self.init_weights() + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer. 
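+ Returns: (xi, xt) visual and textual features.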
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, std=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else self.text_dim + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * len(transformer))) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +class XLMRobertaWithHead(XLMRoberta): + + def __init__(self, **kwargs): + self.out_dim = kwargs.pop('out_dim') + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential( + nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), + nn.Linear(mid_dim, self.out_dim, bias=False)) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + + def __init__(self, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = None + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. 
+ - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. + Encoded by data.CLIPTokenizer. + """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +def _clip(pretrained=False, + pretrained_name=None, + model_cls=CLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding='eos', + dtype=torch.float32, + device='cpu', + **kwargs): + # init model + if pretrained and pretrained_name: + from sora import BUCKET, DOWNLOAD_TO_CACHE + + # init a meta model + with torch.device('meta'): + model = model_cls(**kwargs) + + # checkpoint path + checkpoint = f'models/clip/{pretrained_name}' + if dtype in (torch.float16, torch.bfloat16): + suffix = '-' + { + torch.float16: 'fp16', + torch.bfloat16: 'bf16' + }[dtype] + if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'): + checkpoint = f'{checkpoint}{suffix}' + checkpoint += '.pth' + + # load + model.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device), + assign=True, + strict=False) + else: + # init a model on device + with torch.device(device): + model = model_cls(**kwargs) + + # set device + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if 'siglip' in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose([ + T.Resize((model.image_size, model.image_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=mean, std=std) + ]) + output += (transforms,) + + # init tokenizer + if return_tokenizer: + from sora import data + if 'siglip' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name=f'timm/{pretrained_name}', + seq_len=model.text_len, + clean='canonicalize') + elif 'xlm' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name='xlm-roberta-large', + seq_len=model.max_text_len - 2, + clean='whitespace') + elif 'mba' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name='facebook/xlm-roberta-xl', + seq_len=model.max_text_len - 2, + clean='whitespace') + else: + tokenizer = data.CLIPTokenizer( + seq_len=model.text_len, padding=tokenizer_padding) + output += (tokenizer,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14( + pretrained=False, + pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', + **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class WanImageEncoder(torch.nn.Module): + + def __init__(self): + super().__init__() + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + 
return_transforms=True, + return_tokenizer=False, + dtype=torch.float32, + device="cpu") + + def encode_image(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([ + F.interpolate( + u, + size=size, + mode='bicubic', + align_corners=False) for u in videos + ]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + out = self.model.visual(videos, use_31_block=True) + return out diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_mot.py b/DiffSynth-Studio/diffsynth/models/wan_video_mot.py new file mode 100644 index 0000000000000000000000000000000000000000..4091c91777355dce91ccefac56679f8b936e7abb --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_mot.py @@ -0,0 +1,169 @@ +import torch +from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP +import einops +import torch.nn as nn + + +class MotSelfAttention(SelfAttention): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__(dim, num_heads, eps) + def forward(self, x, freqs, is_before_attn=False): + if is_before_attn: + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + return q, k, v + else: + return self.o(x) + + +class MotWanAttentionBlock(DiTBlock): + def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): + super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps) + self.block_id = block_id + + self.self_attn = MotSelfAttention(dim, num_heads, eps) + + + def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot): + + # 1. prepare scale parameter + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + + scale_params_mot_ref = self.modulation + t_mod_mot.float() + scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1) + shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2) + + # 2. 
Self-attention + input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa) + # original block self-attn + attn1 = wan_block.self_attn + q = attn1.norm_q(attn1.q(input_x)) + k = attn1.norm_k(attn1.k(input_x)) + v = attn1.v(input_x) + q = rope_apply(q, freqs, attn1.num_heads) + k = rope_apply(k, freqs, attn1.num_heads) + + # mot block self-attn + norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1) + norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot) + norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1) + q_mot,k_mot,v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True) + + tmp_hidden_states = flash_attention( + torch.cat([q, q_mot], dim=-2), + torch.cat([k, k_mot], dim=-2), + torch.cat([v, v_mot], dim=-2), + num_heads=attn1.num_heads) + + attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2) + + attn_output = attn1.o(attn_output) + x = wan_block.gate(x, gate_msa, attn_output) + + attn_output_mot = self.self_attn(x=attn_output_mot,freqs=freqs_mot, is_before_attn=False) + # gate + attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1) + attn_output_mot = attn_output_mot * gate_msa_mot_ref + attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1) + x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot) + + # 3. cross-attention and feed-forward + x = x + wan_block.cross_attn(wan_block.norm3(x), context) + input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp) + x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x)) + + x_mot = x_mot + self.cross_attn(self.norm3(x_mot),context_mot) + # modulate + norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1) + norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot) + norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1) + input_x_mot = self.ffn(norm_x_mot_ref) + # gate + input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1) + input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref + input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1) + x_mot = (x_mot.float() + input_x_mot).type_as(x_mot) + + return x, x_mot + + +class MotWanModel(torch.nn.Module): + def __init__( + self, + mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36), + patch_size=(1, 2, 2), + has_image_input=True, + has_image_pos_emb=False, + dim=5120, + num_heads=40, + ffn_dim=13824, + freq_dim=256, + text_dim=4096, + in_dim=36, + eps=1e-6, + ): + super().__init__() + self.mot_layers = mot_layers + self.freq_dim = freq_dim + self.dim = dim + + self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)} + self.head_dim = dim // num_heads + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + if has_image_input: + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) + + # mot blocks + self.blocks = torch.nn.ModuleList([ + MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i) + for i in self.mot_layers + ]) + + + def patchify(self, x: torch.Tensor): + x = 
self.patch_embedding(x) + return x + + def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0): + def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta) + h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta) + w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta) + + freqs = torch.cat([ + f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1), + h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1), + w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1) + return freqs + + def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id): + block = self.blocks[self.mot_layers_mapping[block_id]] + x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot) + return x, x_mot diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_motion_controller.py b/DiffSynth-Studio/diffsynth/models/wan_video_motion_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..34763a8d76e57bc8efff84f23863938cc2309029 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_motion_controller.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +from .wan_video_dit import sinusoidal_embedding_1d + + + +class WanMotionControllerModel(torch.nn.Module): + def __init__(self, freq_dim=256, dim=1536): + super().__init__() + self.freq_dim = freq_dim + self.linear = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim), + nn.SiLU(), + nn.Linear(dim, dim * 6), + ) + + def forward(self, motion_bucket_id): + emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10) + emb = self.linear(emb) + return emb + + def init(self): + state_dict = self.linear[-1].state_dict() + state_dict = {i: state_dict[i] * 0 for i in state_dict} + self.linear[-1].load_state_dict(state_dict) diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_text_encoder.py b/DiffSynth-Studio/diffsynth/models/wan_video_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..64090db8c65138abfdb60a822b3ba2e74fefeb4c --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_text_encoder.py @@ -0,0 +1,330 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoTokenizer +import ftfy +import html +import string +import regex as re + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = 
x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - 
\ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class WanTextEncoder(torch.nn.Module): + + def __init__(self, + vocab=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + num_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1): + super(WanTextEncoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = 
text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', text) + return text.strip() + + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_vace.py b/DiffSynth-Studio/diffsynth/models/wan_video_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..f3367f788891cb22b8a5bf6eaa50cabbc202ab4a --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_vace.py @@ -0,0 +1,87 @@ +import torch +from .wan_video_dit import DiTBlock + + +class VaceWanAttentionBlock(DiTBlock): + def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): + super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = torch.nn.Linear(self.dim, self.dim) + self.after_proj = torch.nn.Linear(self.dim, self.dim) + + def forward(self, c, x, context, t_mod, freqs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + c = super().forward(c, context, t_mod, freqs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class VaceWanModel(torch.nn.Module): + def __init__( + self, + vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28), + vace_in_dim=96, + patch_size=(1, 2, 2), + has_image_input=False, + dim=1536, + num_heads=12, + ffn_dim=8960, + eps=1e-6, + ): + super().__init__() + self.vace_layers = vace_layers + self.vace_in_dim = vace_in_dim + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # vace blocks + self.vace_blocks = torch.nn.ModuleList([ + VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i) + for i in self.vace_layers + ]) + + # vace patch embeddings + self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size) + + def forward( + self, x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + ): + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), 
u.size(2))], + dim=1) for u in c + ]) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block in self.vace_blocks: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + c = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + c, x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + c = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + c, x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + c = block(c, x, context, t_mod, freqs) + hints = torch.unbind(c)[:-1] + return hints diff --git a/DiffSynth-Studio/diffsynth/models/wan_video_vae.py b/DiffSynth-Studio/diffsynth/models/wan_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..d24e29d9398f95a59cbd1466542c6e7059f7c7af --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wan_video_vae.py @@ -0,0 +1,1382 @@ +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +CACHE_T = 2 + + +def check_is_instance(model, module_class): + if isinstance(model, module_class): + return True + if hasattr(model, "module") and isinstance(model.module, module_class): + return True + return False + + +def block_causal_mask(x, block_size): + # params + b, n, s, _, device = *x.size(), x.device + assert s % block_size == 0 + num_blocks = s // block_size + + # build mask + mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device) + for i in range(num_blocks): + mask[:, :, + i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1 + return mask + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
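+ Casts the input to float32 for interpolation, then casts back to the original dtype.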
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d(dim, + dim * 2, (3, 1, 1), + padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, + dim, (3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + 
r=patch_size) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size) + return x + + +class Resample38(Resample): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super(Resample, self).__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + else: + self.resample = nn.Identity() + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. 
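+
+    Attention is applied independently to every frame over its spatial (h * w)
+    tokens; the block-causal mask in forward() is currently commented out, so
+    no causal masking is actually applied.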
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute( + 0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + #attn_mask=block_causal_mask(q, block_size=h * w) + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False + ): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else 
"downsample2d" + downsamples.append(Resample38(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False + ): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample38(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, 
feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Encoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] if i < len(temperal_downsample) else False + ) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + ) + ) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u 
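+        # decoder channel widths mirror the encoder: start from the deepest
+        # dim_mult entry and walk dim_mult in reverse back towards the base dim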
for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + + +class Decoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False + upsamples.append( + Up_ResidualBlock(in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1)) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + 
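+        # the head emits 12 channels per latent patch; unpatchify(..., patch_size=2)
+        # in VideoVAE38_.decode folds them back into 3 image channels at 2x resolution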
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1)) + + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class VideoVAE_(nn.Module): + + def __init__(self, + dim=96, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = 
scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) # may add tensor offload + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +class WanVideoVAE(nn.Module): + + def __init__(self, z_dim=16): + super().__init__() + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False) + self.upsampling_factor = 8 + self.z_dim = z_dim + + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + + def build_mask(self, data, is_bound, border_width): + _, _, _, H, W = data.shape + h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0]) + w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1]) + + h = repeat(h, "H -> H W", H=H, W=W) + w = repeat(w, "W -> H W", H=H, W=W) + + mask = torch.stack([h, w]).min(dim=0).values + mask = rearrange(mask, "H W -> 1 1 1 H W") + return mask + + + def tiled_decode(self, hidden_states, device, tile_size, tile_stride): + _, _, T, H, W = hidden_states.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = T * 4 - 3 + weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"): + hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + 
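+                # build_mask linearly feathers the edges of interior tiles so that
+                # overlapping decoded regions blend smoothly instead of leaving seams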
is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor) + ).to(dtype=hidden_states.dtype, device=data_device) + + target_h = h * self.upsampling_factor + target_w = w * self.upsampling_factor + values[ + :, + :, + :, + target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + values = values.clamp_(-1, 1) + return values + + + def tiled_encode(self, video, device, tile_size, tile_stride): + _, _, T, H, W = video.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = (T + 3) // 4 + weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): + hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor) + ).to(dtype=video.dtype, device=data_device) + + target_h = h // self.upsampling_factor + target_w = w // self.upsampling_factor + values[ + :, + :, + :, + target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + return values + + + def single_encode(self, video, device): + video = video.to(device) + x = self.model.encode(video, self.scale) + return x + + + def single_decode(self, hidden_state, device): + hidden_state = hidden_state.to(device) + video = self.model.decode(hidden_state, self.scale) + return video.clamp_(-1, 1) + + + def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + videos = [video.to("cpu") for video in videos] + hidden_states = [] + for video in videos: + video = video.unsqueeze(0) + if tiled: + tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor) + tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor) + hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) + else: + hidden_state = self.single_encode(video, device) + hidden_state = hidden_state.squeeze(0) + hidden_states.append(hidden_state) + hidden_states = torch.stack(hidden_states) + return hidden_states + + + def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] + 
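+        # latents are kept on CPU here; single_decode / tiled_decode move the data
+        # (or individual tiles) to the compute device on demand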
videos = [] + for hidden_state in hidden_states: + hidden_state = hidden_state.unsqueeze(0) + if tiled: + video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_state, device) + video = video.squeeze(0) + videos.append(video) + videos = torch.stack(videos) + return videos + + + @staticmethod + def state_dict_converter(): + return WanVideoVAEStateDictConverter() + + +class WanVideoVAEStateDictConverter: + + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict_ = {} + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + for name in state_dict: + state_dict_['model.' + name] = state_dict[name] + return state_dict_ + + +class VideoVAE38_(VideoVAE_): + + def __init__(self, + dim=160, + z_dim=48, + dec_dim=256, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super(VideoVAE_, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + + def encode(self, x, scale): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + +class WanVideoVAE38(WanVideoVAE): + + def __init__(self, z_dim=48, dim=160): + super(WanVideoVAE, self).__init__() + + mean = [ + -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, + -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, + -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, + 
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, + -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, + 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667 + ] + std = [ + 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, + 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, + 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, + 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, + 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, + 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) + self.upsampling_factor = 16 + self.z_dim = z_dim diff --git a/DiffSynth-Studio/diffsynth/models/wav2vec.py b/DiffSynth-Studio/diffsynth/models/wav2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..8807302d815a917123b794a597fc5fe84d3394fc --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/wav2vec.py @@ -0,0 +1,191 @@ +import math +import numpy as np +import torch +import torch.nn.functional as F + + +def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if not fixed_start is None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') # [1, 512, output_len] + return output_features.transpose(1, 2) + + +class WanS2VAudioEncoder(torch.nn.Module): + + def __init__(self): + super().__init__() + from transformers import Wav2Vec2ForCTC, Wav2Vec2Config + config = { + "_name_or_path": "facebook/wav2vec2-large-xlsr-53", + "activation_dropout": 0.05, + "apply_spec_augment": True, + "architectures": ["Wav2Vec2ForCTC"], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": True, + "conv_dim": [512, 512, 512, 512, 512, 512, 512], + "conv_kernel": [10, 3, 3, 3, 3, 2, 2], + "conv_stride": [5, 2, 2, 2, 2, 2, 2], + "ctc_loss_reduction": "mean", + "ctc_zero_infinity": True, + "do_stable_layer_norm": True, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.05, + "final_dropout": 0.0, + "hidden_act": "gelu", + "hidden_dropout": 0.05, + 
"hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.05, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.05, + "mask_time_selection": "static", + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.7.0.dev0", + "vocab_size": 33 + } + self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config)) + self.video_rate = 30 + + def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'): + input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device) + + # retrieve logits & take argmax + res = self.model(input_values, output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate) + return feat + + def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1 + + bucket_num = min_batch_num * batch_frames + batch_idx = [stride * i for i in range(bucket_num)] + batch_audio_eb = [] + for bi in batch_idx: + if bi < audio_frame_num: + audio_sample_stride = 2 + chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = self.video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0 + ) + batch_audio_eb = [] + audio_sample_stride = int(self.video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + + chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * 
audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'): + audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device) + audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype) + audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)] + return audio_embeds diff --git a/DiffSynth-Studio/diffsynth/models/z_image_controlnet.py b/DiffSynth-Studio/diffsynth/models/z_image_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5105534c736d1908fb5210eda0bfbd07c3b281df --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/z_image_controlnet.py @@ -0,0 +1,154 @@ +from .z_image_dit import ZImageTransformerBlock +from ..core.gradient import gradient_checkpoint_forward +from torch.nn.utils.rnn import pad_sequence +import torch +from torch import nn + + +class ZImageControlTransformerBlock(ZImageTransformerBlock): + def __init__( + self, + layer_id: int = 1000, + dim: int = 3840, + n_heads: int = 30, + n_kv_heads: int = 30, + norm_eps: float = 1e-5, + qk_norm: bool = True, + modulation = True, + block_id = 0 + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + self.after_proj = nn.Linear(self.dim, self.dim) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class ZImageControlNet(torch.nn.Module): + def __init__( + self, + control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28), + control_in_dim=33, + dim=3840, + n_refiner_layers=2, + ): + super().__init__() + self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places]) + self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)}) + self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)]) + self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14} + + def forward_layers( + self, + x, + cap_feats, + control_context, + control_context_item_seqlens, + kwargs, + 
use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + bsz = len(control_context) + # unified + cap_item_seqlens = [len(_) for _ in cap_feats] + control_context_unified = [] + for i in range(bsz): + control_context_len = control_context_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]])) + c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for layer in self.control_layers: + c = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + c=c, **new_kwargs + ) + + hints = torch.unbind(c)[:-1] + return hints + + def forward_refiner( + self, + dit, + x, + cap_feats, + control_context, + kwargs, + t=None, + patch_size=2, + f_patch_size=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + # embeddings + bsz = len(control_context) + device = control_context[0].device + ( + control_context, + control_context_size, + control_context_pos_ids, + control_context_inner_pad_mask, + ) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) + + # control_context embed & refine + control_context_item_seqlens = [len(_) for _ in control_context] + assert all(_ % 2 == 0 for _ in control_context_item_seqlens) + control_context_max_item_seqlen = max(control_context_item_seqlens) + + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + # Match t_embedder output dtype to control_context for layerwise casting compatibility + adaln_input = t.type_as(control_context) + control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device) + control_context = list(control_context.split(control_context_item_seqlens, dim=0)) + control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0) + control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(control_context_item_seqlens): + control_context_attn_mask[i, :seq_len] = 1 + c = control_context + + # arguments + new_kwargs = dict( + x=x, + attn_mask=control_context_attn_mask, + freqs_cis=control_context_freqs_cis, + adaln_input=adaln_input, + ) + new_kwargs.update(kwargs) + + for layer in self.control_noise_refiner: + c = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + c=c, **new_kwargs + ) + + hints = torch.unbind(c)[:-1] + control_context = torch.unbind(c)[-1] + + return hints, control_context, control_context_item_seqlens \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/models/z_image_dit.py b/DiffSynth-Studio/diffsynth/models/z_image_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..70574f01a66b053b4da76761dd8a0952a208355e --- /dev/null +++ 
b/DiffSynth-Studio/diffsynth/models/z_image_dit.py @@ -0,0 +1,1152 @@ +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from .general_modules import RMSNorm +from ..core.attention import attention_forward +from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type +from ..core.gradient import gradient_checkpoint_forward + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 +X_PAD_DIM = 64 + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + mid_size, + bias=True, + ), + nn.SiLU(), + nn.Linear( + mid_size, + out_size, + bias=True, + ), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast(get_device_type(), enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(torch.bfloat16)) + return t_emb + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim_inner, q_dim, bias=bias_out)]) + + self.norm_q = RMSNorm(head_dim, eps=1e-5) + self.norm_k = RMSNorm(head_dim, eps=1e-5) + + def forward(self, hidden_states, freqs_cis, attention_mask): + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + query = query.unflatten(-1, (self.num_heads, -1)) + key = key.unflatten(-1, (self.num_heads, -1)) + value = value.unflatten(-1, (self.num_heads, -1)) + + # Apply Norms + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast(get_device_type(), enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: 
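+            # rotary embedding is applied to queries and keys in float32 complex
+            # space (see apply_rotary_emb above) and cast back to the input dtype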
+ query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # Compute joint attention + hidden_states = attention_forward( + query, + key, + value, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + attn_mask=attention_mask, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = self.to_out[0](hidden_states) + if len(self.to_out) > 1: # dropout + output = self.to_out[1](output) + + return output + + +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + q_dim=dim, + num_heads=n_heads, + head_dim=dim // n_heads, + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, + ): + if self.modulation: + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + 
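+                # gates go through tanh and scales are offset by 1.0, so near-zero
+                # modulation outputs leave the block close to an identity mapping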
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) + scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + if IS_NPU_AVAILABLE: + result.append(torch.index_select(self.freqs_cis[i], 0, index)) + else: + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +class ZImageDiT(nn.Module): + _supports_gradient_checkpointing = True + _no_split_modules = ["ZImageTransformerBlock"] + + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 
512], + siglip_feat_dim=None, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear(cap_feat_dim, dim, bias=True), + ) + + # Optional SigLIP components (for Omni variant) + self.siglip_feat_dim = siglip_feat_dim + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + head_dim = dim // n_heads + assert head_dim == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + def unpatchify( + self, + x: List[torch.Tensor], + size: List[Tuple], + patch_size = 2, + f_patch_size = 1, + x_pos_offsets: Optional[List[Tuple[int, int]]] = None, + ) -> List[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + + if x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + if size[i][j] is None: + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += pad_len + ori_len + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + 
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + cu_len += ori_len + pad_len + result.append(x_item) # Return only the last (target) image + return result + else: + # Original mode: simple unpatchify + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int = 2, + f_patch_size: int = 1, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + all_cap_pad_mask.append( + torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + cap_padded_feat = torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, + ) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return all_image_out, all_cap_feats_out, { + "x_size": all_image_size, + "x_pos_ids": all_image_pos_ids, + "cap_pos_ids": all_cap_pos_ids, + "x_pad_mask": 
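A self-contained round-trip check (illustrative only) for the patchify/unpatchify permutes above: the forward path flattens (C, F, H, W) into (tokens, patch_dim) following "c f pf h ph w pw -> (f h w) (pf ph pw c)", and the permute in unpatchify is its exact inverse.

import torch

C, F, H, W = 16, 1, 8, 8
pF, pH, pW = 1, 2, 2
x = torch.randn(C, F, H, W)

# patchify: "c f pf h ph w pw -> (f h w) (pf ph pw c)"
Ft, Ht, Wt = F // pF, H // pH, W // pW
tokens = (
    x.view(C, Ft, pF, Ht, pH, Wt, pW)
    .permute(1, 3, 5, 2, 4, 6, 0)
    .reshape(Ft * Ht * Wt, pF * pH * pW * C)
)

# unpatchify: "(f h w) (pf ph pw c) -> c (f pf) (h ph) (w pw)"
y = (
    tokens.view(Ft, Ht, Wt, pF, pH, pW, C)
    .permute(6, 0, 3, 1, 4, 2, 5)
    .reshape(C, F, H, W)
)
assert torch.equal(x, y)   # the two permutes are inverses, so the image is recovered exactly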
all_image_pad_mask, + "cap_pad_mask": all_cap_pad_mask + } + # ( + # all_img_out, + # all_cap_out, + # all_img_size, + # all_img_pos_ids, + # all_cap_pos_ids, + # all_img_pad_mask, + # all_cap_pad_mask, + # ) + + def patchify_controlnet( + self, + all_image: List[torch.Tensor], + patch_size: int = 2, + f_patch_size: int = 1, + cap_padding_len: int = None, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_image_size, + all_image_pos_ids, + all_image_pad_mask, + ) + + def _prepare_sequence( + self, + feats: List[torch.Tensor], + pos_ids: List[torch.Tensor], + inner_pad_mask: List[torch.Tensor], + pad_token: torch.nn.Parameter, + noise_mask: Optional[List[List[int]]] = None, + device: torch.device = None, + ): + """Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask.""" + item_seqlens = [len(f) for f in feats] + max_seqlen = max(item_seqlens) + bsz = len(feats) + + # Pad token + feats_cat = torch.cat(feats, dim=0) + feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device) + feats = list(feats_cat.split(item_seqlens, dim=0)) + + # RoPE + freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0)) + + # Pad to batch + feats = pad_sequence(feats, batch_first=True, padding_value=0.0) + freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(item_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if noise_mask is not None: + noise_mask_tensor = pad_sequence( + [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask], + batch_first=True, + padding_value=0, + )[:, : feats.shape[1]] + + return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor + + def _build_unified_sequence( + self, + x: torch.Tensor, + x_freqs: torch.Tensor, + x_seqlens: List[int], + x_noise_mask: Optional[List[List[int]]], + cap: 
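For reference, a small sketch (assumed toy shapes) of the batching pattern _prepare_sequence relies on: torch.nn.utils.rnn.pad_sequence right-pads variable-length items to the batch maximum, and the boolean attention mask marks the valid prefix of each row.

import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.randn(5, 8), torch.randn(3, 8)]       # two items of different lengths, feature dim 8
item_seqlens = [f.shape[0] for f in feats]

batch = pad_sequence(feats, batch_first=True, padding_value=0.0)   # (2, 5, 8)
attn_mask = torch.zeros(len(feats), max(item_seqlens), dtype=torch.bool)
for i, n in enumerate(item_seqlens):
    attn_mask[i, :n] = True                          # valid tokens only

print(batch.shape, attn_mask.sum(dim=1))             # torch.Size([2, 5, 8]) tensor([5, 3])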
torch.Tensor, + cap_freqs: torch.Tensor, + cap_seqlens: List[int], + cap_noise_mask: Optional[List[List[int]]], + siglip: Optional[torch.Tensor], + siglip_freqs: Optional[torch.Tensor], + siglip_seqlens: Optional[List[int]], + siglip_noise_mask: Optional[List[List[int]]], + omni_mode: bool, + device: torch.device, + ): + """Build unified sequence: x, cap, and optionally siglip. + Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip] + """ + bsz = len(x_seqlens) + unified = [] + unified_freqs = [] + unified_noise_mask = [] + + for i in range(bsz): + x_len, cap_len = x_seqlens[i], cap_seqlens[i] + + if omni_mode: + # Omni: [cap, x, siglip] + if siglip is not None and siglip_seqlens is not None: + sig_len = siglip_seqlens[i] + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]])) + unified_freqs.append( + torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device + ) + ) + else: + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]])) + unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + else: + # Basic: [x, cap] + unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]])) + unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) + + # Compute unified seqlens + if omni_mode: + if siglip is not None and siglip_seqlens is not None: + unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)] + + max_seqlen = max(unified_seqlens) + + # Pad to batch + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if omni_mode: + noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[ + :, : unified.shape[1] + ] + + return unified, unified_freqs, attn_mask, noise_mask_tensor + + def _pad_with_ids( + self, + feat: torch.Tensor, + pos_grid_size: Tuple, + pos_start: Tuple, + device: torch.device, + noise_mask_val: Optional[int] = None, + ): + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + # Pos IDs + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if 
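A quick worked example of the padding arithmetic used by _pad_with_ids and patchify_and_embed above: sequences are extended to the next multiple of SEQ_MULTI_OF by repeating the last row, with the pad length computed as (-ori_len) % SEQ_MULTI_OF. SEQ_MULTI_OF is a module-level constant defined elsewhere in this file; 32 below is only an assumed value for illustration.

SEQ_MULTI_OF = 32   # assumed value for this sketch

for ori_len in (7, 32, 70):
    pad_len = (-ori_len) % SEQ_MULTI_OF      # distance to the next multiple (0 if already aligned)
    total_len = ori_len + pad_len
    assert total_len % SEQ_MULTI_OF == 0 and 0 <= pad_len < SEQ_MULTI_OF
    print(ori_len, "->", total_len)          # 7 -> 32, 32 -> 32, 70 -> 96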
noise_mask_val is not None else None # token level + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) + + def patchify_and_embed_omni( + self, + all_x: List[List[torch.Tensor]], + all_cap_feats: List[List[torch.Tensor]], + all_siglip_feats: List[List[torch.Tensor]], + patch_size: int = 2, + f_patch_size: int = 1, + images_noise_mask: List[List[int]] = None, + ): + """Patchify for omni mode: multiple images per batch item with noise masks.""" + bsz = len(all_x) + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype + + all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], [] + all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], [] + + for i in range(bsz): + num_images = len(all_x[i]) + cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], [] + cap_end_pos = [] + cap_cu_len = 1 + + # Process captions + for j, cap_item in enumerate(all_cap_feats[i]): + noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1 + cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids( + cap_item, + (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1), + (cap_cu_len, 0, 0), + device, + noise_val, + ) + cap_feats_list.append(cap_out) + cap_pos_list.append(cap_pos) + cap_mask_list.append(cap_mask) + cap_lens.append(cap_len) + cap_noise.extend(cap_nm) + cap_cu_len += len(cap_item) + cap_end_pos.append(cap_cu_len) + cap_cu_len += 2 # for image vae and siglip tokens + + all_cap_out.append(torch.cat(cap_feats_list, dim=0)) + all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0)) + all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0)) + all_cap_len.append(cap_lens) + all_cap_noise_mask.append(cap_noise) + + # Process images + x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], [] + for j, x_item in enumerate(all_x[i]): + noise_val = images_noise_mask[i][j] + if x_item is not None: + x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size) + x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids( + x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val + ) + x_size.append(size) + else: + x_len = SEQ_MULTI_OF + x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device) + x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1) + x_mask = torch.ones(x_len, dtype=torch.bool, device=device) + x_nm = [noise_val] * x_len + x_size.append(None) + x_feats_list.append(x_out) + x_pos_list.append(x_pos) + x_mask_list.append(x_mask) + x_lens.append(x_len) + x_noise.extend(x_nm) + + all_x_out.append(torch.cat(x_feats_list, dim=0)) + all_x_pos_ids.append(torch.cat(x_pos_list, dim=0)) + all_x_pad_mask.append(torch.cat(x_mask_list, dim=0)) + all_x_size.append(x_size) + all_x_len.append(x_lens) + 
all_x_noise_mask.append(x_noise) + + # Process siglip + if all_siglip_feats[i] is None: + all_sig_len.append([0] * num_images) + all_sig_out.append(None) + else: + sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], [] + for j, sig_item in enumerate(all_siglip_feats[i]): + noise_val = images_noise_mask[i][j] + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C) + sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids( + sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val + ) + # Scale position IDs to match x resolution + if x_size[j] is not None: + sig_pos = sig_pos.float() + sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1) + sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1) + sig_pos = sig_pos.to(torch.int32) + else: + sig_len = SEQ_MULTI_OF + sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device) + sig_pos = ( + self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1) + ) + sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device) + sig_nm = [noise_val] * sig_len + sig_feats_list.append(sig_out) + sig_pos_list.append(sig_pos) + sig_mask_list.append(sig_mask) + sig_lens.append(sig_len) + sig_noise.extend(sig_nm) + + all_sig_out.append(torch.cat(sig_feats_list, dim=0)) + all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0)) + all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0)) + all_sig_len.append(sig_lens) + all_sig_noise_mask.append(sig_noise) + + # Compute x position offsets + all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)] + + return ( + all_x_out, + all_cap_out, + all_sig_out, + all_x_size, + all_x_pos_ids, + all_cap_pos_ids, + all_sig_pos_ids, + all_x_pad_mask, + all_cap_pad_mask, + all_sig_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_sig_noise_mask, + ) + return all_x_out, all_cap_out, all_sig_out, { + "x_size": x_size, + "x_pos_ids": all_x_pos_ids, + "cap_pos_ids": all_cap_pos_ids, + "sig_pos_ids": all_sig_pos_ids, + "x_pad_mask": all_x_pad_mask, + "cap_pad_mask": all_cap_pad_mask, + "sig_pad_mask": all_sig_pad_mask, + "x_pos_offsets": all_x_pos_offsets, + "x_noise_mask": all_x_noise_mask, + "cap_noise_mask": all_cap_noise_mask, + "sig_noise_mask": all_sig_noise_mask, + } + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + siglip_feats = None, + image_noise_mask = None, + patch_size=2, + f_patch_size=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size + omni_mode = isinstance(x[0], list) + device = x[0][-1].device if omni_mode else x[0].device + + if omni_mode: + # Dual embeddings: noisy (t) and clean (t=1) + t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1]) + t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1]) + adaln_input = None + else: + # Single embedding for all tokens + adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0]) + t_noisy = t_clean = None + + # Patchify + if omni_mode: + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + siglip_pos_ids, + x_pad_mask, + cap_pad_mask, + siglip_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = 
self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + else: + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_pad_mask, + cap_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None + + # x embed & refine + x_seqlens = [len(xi) for xi in x] + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed + x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence( + list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device + ) + + for layer in self.noise_refiner: + x = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=x, attn_mask=x_mask, freqs_cis=x_freqs, adaln_input=adaln_input, noise_mask=x_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean, + ) + + # Cap embed & refine + cap_seqlens = [len(ci) for ci in cap_feats] + cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed + cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence( + list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device + ) + + for layer in self.context_refiner: + cap_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=cap_feats, + attn_mask=cap_mask, + freqs_cis=cap_freqs, + ) + + # Siglip embed & refine + siglip_seqlens = siglip_freqs = None + if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_seqlens = [len(si) for si in siglip_feats] + siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed + siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence( + list(siglip_feats.split(siglip_seqlens, dim=0)), + siglip_pos_ids, + siglip_pad_mask, + self.siglip_pad_token, + None, + device, + ) + + for layer in self.siglip_refiner: + siglip_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs, + ) + + # Unified sequence + unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence( + x, + x_freqs, + x_seqlens, + x_noise_mask, + cap_feats, + cap_freqs, + cap_seqlens, + cap_noise_mask, + siglip_feats, + siglip_freqs, + siglip_seqlens, + siglip_noise_mask, + omni_mode, + device, + ) + + # Main transformer layers + for layer_idx, layer in enumerate(self.layers): + unified = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=unified, attn_mask=unified_mask, freqs_cis=unified_freqs, adaln_input=adaln_input, noise_mask=unified_noise_tensor, adaln_noisy=t_noisy, adaln_clean=t_clean + ) + + unified = ( + self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean + ) + if omni_mode + else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input) + ) + + # Unpatchify + x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets) + + return x diff --git 
a/DiffSynth-Studio/diffsynth/models/z_image_image2lora.py b/DiffSynth-Studio/diffsynth/models/z_image_image2lora.py new file mode 100644 index 0000000000000000000000000000000000000000..757f3f6778bb24187333451bf30d6773b867ad77 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/z_image_image2lora.py @@ -0,0 +1,189 @@ +import torch +from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP + + +class LoRATrainerBlock(torch.nn.Module): + def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"): + super().__init__() + self.prefix = prefix + self.lora_patterns = lora_patterns + self.block_id = block_id + self.layers = [] + for name, lora_a_dim, lora_b_dim in self.lora_patterns: + self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)) + self.layers = torch.nn.ModuleList(self.layers) + if use_residual: + self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim) + else: + self.proj_residual = None + + def forward(self, x, residual=None): + lora = {} + if self.proj_residual is not None: residual = self.proj_residual(residual) + for lora_pattern, layer in zip(self.lora_patterns, self.layers): + name = lora_pattern[0] + lora_a, lora_b = layer(x, residual=residual) + lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a + lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b + return lora + + +class ZImageImage2LoRAComponent(torch.nn.Module): + def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = lora_patterns + self.num_blocks = num_blocks + self.blocks = [] + for lora_patterns in self.lora_patterns: + for block_id in range(self.num_blocks): + self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix)) + self.blocks = torch.nn.ModuleList(self.blocks) + self.residual_scale = 0.05 + self.use_residual = use_residual + + def forward(self, x, residual=None): + if residual is not None: + if self.use_residual: + residual = residual * self.residual_scale + else: + residual = None + lora = {} + for block in self.blocks: + lora.update(block(x, residual)) + return lora + + +class ZImageImage2LoRAModel(torch.nn.Module): + def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + lora_patterns = [ + [ + ("attention.to_q", 3840, 3840), + ("attention.to_k", 3840, 3840), + ("attention.to_v", 3840, 3840), + ("attention.to_out.0", 3840, 3840), + ], + [ + ("feed_forward.w1", 3840, 10240), + ("feed_forward.w2", 10240, 3840), + ("feed_forward.w3", 3840, 10240), + ], + ] + config = { + "lora_patterns": lora_patterns, + "use_residual": use_residual, + "compress_dim": compress_dim, + "rank": rank, + "residual_length": residual_length, + "residual_mid_dim": residual_mid_dim, + } + self.layers_lora = ZImageImage2LoRAComponent( + prefix="layers", + num_blocks=30, + **config, + ) + self.context_refiner_lora = ZImageImage2LoRAComponent( + prefix="context_refiner", + num_blocks=2, + **config, + ) + self.noise_refiner_lora = ZImageImage2LoRAComponent( + 
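A small illustration (plain string formatting, no model weights) of the key naming convention LoRATrainerBlock emits above; the generated dictionary keys follow the PEFT-style "prefix.block.layer.lora_A/lora_B.default.weight" layout, so they can be applied directly as a LoRA state dict.

prefix, block_id = "layers", 0
lora_patterns = [("attention.to_q", 3840, 3840), ("feed_forward.w1", 3840, 10240)]

keys = []
for name, _, _ in lora_patterns:
    keys.append(f"{prefix}.{block_id}.{name}.lora_A.default.weight")
    keys.append(f"{prefix}.{block_id}.{name}.lora_B.default.weight")
print(keys)
# ['layers.0.attention.to_q.lora_A.default.weight', 'layers.0.attention.to_q.lora_B.default.weight', ...]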
prefix="noise_refiner", + num_blocks=2, + **config, + ) + + def forward(self, x, residual=None): + lora = {} + lora.update(self.layers_lora(x, residual=residual)) + lora.update(self.context_refiner_lora(x, residual=residual)) + lora.update(self.noise_refiner_lora(x, residual=residual)) + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if ".proj_a." in name: + state_dict[name] = state_dict[name] * 0.3 + elif ".proj_b.proj_out." in name: + state_dict[name] = state_dict[name] * 0 + elif ".proj_residual.proj_out." in name: + state_dict[name] = state_dict[name] * 0.3 + self.load_state_dict(state_dict) + + +class ImageEmb2LoRAWeightCompressed(torch.nn.Module): + def __init__(self, in_dim, out_dim, emb_dim, rank): + super().__init__() + self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim))) + self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank))) + self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True) + self.rank = rank + + def forward(self, x): + x = self.proj(x).view(self.rank, self.rank) + lora_a = x @ self.lora_a + lora_b = self.lora_b + return lora_a, lora_b + + +class ZImageImage2LoRAModelCompressed(torch.nn.Module): + def __init__(self, emb_dim=1536+4096, rank=32): + super().__init__() + target_layers = [ + ("attention.to_q", 3840, 3840), + ("attention.to_k", 3840, 3840), + ("attention.to_v", 3840, 3840), + ("attention.to_out.0", 3840, 3840), + ("feed_forward.w1", 3840, 10240), + ("feed_forward.w2", 10240, 3840), + ("feed_forward.w3", 3840, 10240), + ] + self.lora_patterns = [ + { + "prefix": "layers", + "num_layers": 30, + "target_layers": target_layers, + }, + { + "prefix": "context_refiner", + "num_layers": 2, + "target_layers": target_layers, + }, + { + "prefix": "noise_refiner", + "num_layers": 2, + "target_layers": target_layers, + }, + ] + module_dict = {} + for lora_pattern in self.lora_patterns: + prefix, num_layers, target_layers = lora_pattern["prefix"], lora_pattern["num_layers"], lora_pattern["target_layers"] + for layer_id in range(num_layers): + for layer_name, in_dim, out_dim in target_layers: + name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___") + model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank) + module_dict[name] = model + self.module_dict = torch.nn.ModuleDict(module_dict) + + def forward(self, x, residual=None): + lora = {} + for name, module in self.module_dict.items(): + name = name.replace("___", ".") + name_a, name_b = f"{name}.lora_A.default.weight", f"{name}.lora_B.default.weight" + lora_a, lora_b = module(x) + lora[name_a] = lora_a + lora[name_b] = lora_b + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if "lora_b" in name: + state_dict[name] = state_dict[name] * 0 + elif "lora_a" in name: + state_dict[name] = state_dict[name] * 0.2 + elif "proj.weight" in name: + print(name) + state_dict[name] = state_dict[name] * 0.2 + self.load_state_dict(state_dict) diff --git a/DiffSynth-Studio/diffsynth/models/z_image_text_encoder.py b/DiffSynth-Studio/diffsynth/models/z_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6271dca625c56ac5cd05d12cb6628febc83da6 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/models/z_image_text_encoder.py @@ -0,0 +1,74 @@ +from transformers import Qwen3Model, Qwen3Config +import torch + + +class ZImageTextEncoder(torch.nn.Module): + def __init__(self, model_size="4B"): + super().__init__() + config_dict = { + "4B": Qwen3Config(**{ 
+ "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 9728, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }), + "8B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": False, + "transformers_version": "4.56.1", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }) + } + config = config_dict[model_size] + self.model = Qwen3Model(config) + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) diff --git a/DiffSynth-Studio/diffsynth/pipelines/__pycache__/wan_video.cpython-39.pyc b/DiffSynth-Studio/diffsynth/pipelines/__pycache__/wan_video.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22b716c37994ae81d8a1ee82febd3bd5ebcfff33 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/pipelines/__pycache__/wan_video.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/pipelines/flux2_image.py b/DiffSynth-Studio/diffsynth/pipelines/flux2_image.py new file mode 100644 index 0000000000000000000000000000000000000000..d5dc35bd4f139c828dbd60bf5ac9aa8f34a46932 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/pipelines/flux2_image.py @@ -0,0 +1,591 @@ +import torch, math, torchvision +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np +from typing import Union, List, Optional, Tuple + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput + +from transformers import AutoProcessor, AutoTokenizer +from ..models.flux2_text_encoder import Flux2TextEncoder +from ..models.flux2_dit import Flux2DiT +from ..models.flux2_vae import Flux2VAE +from ..models.z_image_text_encoder import ZImageTextEncoder + + +class Flux2ImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: Flux2TextEncoder = None + self.text_encoder_qwen3: ZImageTextEncoder = None + self.dit: Flux2DiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = 
("dit",) + self.units = [ + Flux2Unit_ShapeChecker(), + Flux2Unit_PromptEmbedder(), + Flux2Unit_Qwen3PromptEmbedder(), + Flux2Unit_NoiseInitializer(), + Flux2Unit_InputImageEmbedder(), + Flux2Unit_EditImageEmbedder(), + Flux2Unit_ImageIDs(), + ] + self.model_fn = model_fn_flux2 + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = Flux2ImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder") + pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("flux2_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + embedded_guidance: float = 4.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Edit + edit_image: Union[Image.Image, List[Image.Image]] = None, + edit_image_auto_resize: bool = True, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, + "input_image": input_image, "denoising_strength": denoising_strength, + "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + latents = rearrange(inputs_shared["latents"], "B (H W) C -> B C H W", H=inputs_shared["height"]//16, W=inputs_shared["width"]//16) + image = self.vae.decode(latents) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class 
Flux2Unit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: Flux2ImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + +class Flux2Unit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + + def format_text_input(self, prompts: List[str], system_message: str = None): + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + def get_mistral_3_small_prompt_embeds( + self, + text_encoder, + tokenizer, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: List[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = self.format_text_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + def prepare_text_ids( + self, + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + 
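A shape-only sketch (dummy tensors; the hidden size is an assumption) of the layer-stacking step in get_mistral_3_small_prompt_embeds above: three intermediate hidden states are stacked on a new channel axis and folded into the feature dimension, so the DiT receives a (batch, seq_len, 3 * hidden_dim) prompt embedding.

import torch

batch, seq_len, hidden = 1, 512, 5120                                        # hidden size assumed for illustration
hidden_states = [torch.randn(batch, seq_len, hidden) for _ in range(41)]     # dummy per-layer outputs

layers = (10, 20, 30)
out = torch.stack([hidden_states[k] for k in layers], dim=1)                 # (B, 3, L, hidden)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch, seq_len, len(layers) * hidden)
print(prompt_embeds.shape)                                                   # torch.Size([1, 512, 15360])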
text_encoder, + tokenizer, + prompt: Union[str, List[str]], + dtype = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self.get_mistral_3_small_prompt_embeds( + text_encoder=text_encoder, + tokenizer=tokenizer, + prompt=prompt, + dtype=dtype, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self.prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def process(self, pipe: Flux2ImagePipeline, prompt): + # Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder) + if pipe.text_encoder_qwen3 is not None: + return {} + + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds, text_ids = self.encode_prompt( + pipe.text_encoder, pipe.tokenizer, prompt, + dtype=pipe.torch_dtype, device=pipe.device, + ) + return {"prompt_embeds": prompt_embeds, "text_ids": text_ids} + + +class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder_qwen3",) + ) + self.hidden_states_layers = (9, 18, 27) # Qwen3 layers + + def get_qwen3_prompt_embeds( + self, + text_encoder: ZImageTextEncoder, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + with torch.inference_mode(): + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + return prompt_embeds + + def prepare_text_ids( + self, + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: 
Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + text_encoder: ZImageTextEncoder, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + dtype = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self.get_qwen3_prompt_embeds( + text_encoder=text_encoder, + tokenizer=tokenizer, + prompt=prompt, + dtype=dtype, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self.prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def process(self, pipe: Flux2ImagePipeline, prompt): + # Check if Qwen3 text encoder is available + if pipe.text_encoder_qwen3 is None: + return {} + + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds, text_ids = self.encode_prompt( + pipe.text_encoder_qwen3, pipe.tokenizer, prompt, + dtype=pipe.torch_dtype, device=pipe.device, + ) + return {"prompt_embeds": prompt_embeds, "text_ids": text_ids} + + +class Flux2Unit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1) + return {"noise": noise} + + +class Flux2Unit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: Flux2ImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + input_latents = rearrange(input_latents, "B C H W -> B (H W) C") + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +class Flux2Unit_EditImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image", "edit_image_auto_resize"), + output_params=("edit_latents", "edit_image_ids"), + onload_model_names=("vae",) + ) + + def calculate_dimensions(self, target_area, ratio): + import math + width = math.sqrt(target_area * ratio) + height = width / ratio + width = round(width / 32) * 32 + height = round(height / 32) * 32 + return width, height + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width 
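A quick worked example (pure arithmetic, no dependencies beyond math) of the auto-resize rule in calculate_dimensions above: the edit image is mapped to roughly a 1024x1024 pixel budget while preserving its aspect ratio, with both sides rounded to multiples of 32.

import math

def calculate_dimensions(target_area, ratio):
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return round(width / 32) * 32, round(height / 32) * 32

# A 1920x1080 input (16:9) keeps its aspect ratio but is rescaled to about one megapixel:
print(calculate_dimensions(1024 * 1024, 1920 / 1080))   # (1376, 768)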
/ width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + def edit_image_auto_resize(self, edit_image): + calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1]) + return self.crop_and_resize(edit_image, calculated_height, calculated_width) + + def process_image_ids(self, image_latents, scale=10): + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if isinstance(edit_image, Image.Image): + edit_image = [edit_image] + resized_edit_image, edit_latents = [], [] + for image in edit_image: + # Preprocess + if edit_image_auto_resize is None or edit_image_auto_resize: + image = self.edit_image_auto_resize(image) + resized_edit_image.append(image) + # Encode + image = pipe.preprocess_image(image) + latents = pipe.vae.encode(image) + edit_latents.append(latents) + edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device) + edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1) + return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids} + + +class Flux2Unit_ImageIDs(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("image_ids",), + ) + + def prepare_latent_ids(self, height, width): + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(1, -1, -1) + + return latent_ids + + def process(self, pipe: Flux2ImagePipeline, height, width): + image_ids = self.prepare_latent_ids(height // 16, width // 16).to(pipe.device) + return {"image_ids": image_ids} + + +def model_fn_flux2( + dit: Flux2DiT, + latents=None, + timestep=None, + embedded_guidance=None, + prompt_embeds=None, + text_ids=None, + image_ids=None, + edit_latents=None, + edit_image_ids=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + image_seq_len = latents.shape[1] + if edit_latents is not None: + image_seq_len = latents.shape[1] + latents = torch.concat([latents, edit_latents], dim=1) + image_ids = torch.concat([image_ids, edit_image_ids], dim=1) + embedded_guidance = torch.tensor([embedded_guidance], device=latents.device) + model_output = dit( + hidden_states=latents, + timestep=timestep / 1000, + guidance=embedded_guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=image_ids, + use_gradient_checkpointing=use_gradient_checkpointing, + 
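A tiny sketch of the 4-axis position ids the FLUX.2 units above build with torch.cartesian_prod: target-image latents sit at t = 0 and enumerate (h, w), text tokens keep t = h = w = 0 and enumerate only the sequence axis, and (per process_image_ids) edit-image latents are shifted to larger t offsets of 10, 20, and so on.

import torch

# Target-image ids: (H*W, 4) rows of (t, h, w, l) with t = l = 0.
h, w = 2, 3
latent_ids = torch.cartesian_prod(torch.arange(1), torch.arange(h), torch.arange(w), torch.arange(1))
print(latent_ids.shape)         # torch.Size([6, 4])
print(latent_ids[:3].tolist())  # [[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0]]

# Text ids: only the last (sequence) axis varies.
seq_len = 4
text_ids = torch.cartesian_prod(torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(seq_len))
print(text_ids.tolist())        # [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 3]]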
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + model_output = model_output[:, :image_seq_len] + return model_output diff --git a/DiffSynth-Studio/diffsynth/pipelines/flux_image.py b/DiffSynth-Studio/diffsynth/pipelines/flux_image.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc53e505c2364b52df4627589dd22cc8e4801e6 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/pipelines/flux_image.py @@ -0,0 +1,1206 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.flux import FluxLoRALoader + +from ..models.flux_dit import FluxDiT +from ..models.flux_text_encoder_clip import FluxTextEncoderClip +from ..models.flux_text_encoder_t5 import FluxTextEncoderT5 +from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.flux_value_control import MultiValueEncoder +from ..models.step1x_text_encoder import Step1xEditEmbedder +from ..core.vram.layers import AutoWrappedLinear + +class MultiControlNet(torch.nn.Module): + def __init__(self, models: list[torch.nn.Module]): + super().__init__() + if not isinstance(models, list): + models = [models] + self.models = torch.nn.ModuleList(models) + + def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs): + model = self.models[controlnet_input.controlnet_id] + res_stack, single_res_stack = model( + controlnet_conditioning=conditioning, + processor_id=controlnet_input.processor_id, + **kwargs + ) + res_stack = [res * controlnet_input.scale for res in res_stack] + single_res_stack = [res * controlnet_input.scale for res in single_res_stack] + return res_stack, single_res_stack + + def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs): + res_stack, single_res_stack = None, None + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1) + if progress > controlnet_input.start or progress < controlnet_input.end: + continue + res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs) + if res_stack is None: + res_stack = res_stack_ + single_res_stack = single_res_stack_ + else: + res_stack = [i + j for i, j in zip(res_stack, res_stack_)] + single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)] + return res_stack, single_res_stack + + +class FluxImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.1") + self.tokenizer_1: CLIPTokenizer = None + self.tokenizer_2: T5TokenizerFast = None + self.text_encoder_1: FluxTextEncoderClip = None + self.text_encoder_2: FluxTextEncoderT5 = None + self.dit: FluxDiT = None + self.vae_decoder: FluxVAEDecoder = None + self.vae_encoder: FluxVAEEncoder = None + self.controlnet = None + self.ipadapter = None + 
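For clarity, a self-contained sketch (dummy tensors) of how MultiControlNet above merges several ControlNet branches: each branch's residual stack is scaled by its own weight and the stacks are summed element-wise, block by block.

import torch

def merge_res_stacks(stacks, scales):
    # stacks: list over controlnets, each a list of per-block residual tensors
    merged = None
    for stack, scale in zip(stacks, scales):
        scaled = [res * scale for res in stack]
        merged = scaled if merged is None else [a + b for a, b in zip(merged, scaled)]
    return merged

stack_a = [torch.ones(1, 4, 8) for _ in range(3)]
stack_b = [torch.full((1, 4, 8), 2.0) for _ in range(3)]
merged = merge_res_stacks([stack_a, stack_b], scales=[0.5, 1.0])
print([float(r.mean()) for r in merged])    # [2.5, 2.5, 2.5]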
self.ipadapter_image_encoder = None + self.qwenvl = None + self.step1x_connector = None + self.nexus_gen = None + self.nexus_gen_generation_adapter = None + self.nexus_gen_editing_adapter = None + self.value_controller = None + self.infinityou_processor = None + self.image_proj_model = None + self.lora_patcher = None + self.lora_encoder = None + self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher") + self.units = [ + FluxImageUnit_ShapeChecker(), + FluxImageUnit_NoiseInitializer(), + FluxImageUnit_PromptEmbedder(), + FluxImageUnit_InputImageEmbedder(), + FluxImageUnit_ImageIDs(), + FluxImageUnit_EmbeddedGuidanceEmbedder(), + FluxImageUnit_Kontext(), + FluxImageUnit_InfiniteYou(), + FluxImageUnit_ControlNet(), + FluxImageUnit_IPAdapter(), + FluxImageUnit_EntityControl(), + FluxImageUnit_NexusGen(), + FluxImageUnit_TeaCache(), + FluxImageUnit_Flex(), + FluxImageUnit_Step1x(), + FluxImageUnit_ValueControl(), + FluxImageUnit_LoRAEncode(), + ] + self.model_fn = model_fn_flux_image + self.lora_loader = FluxLoRALoader + + def enable_lora_merger(self): + if not (hasattr(self.dit, "vram_management_enabled") and getattr(self.dit, "vram_management_enabled")): + raise ValueError("DiT VRAM management is not enabled.") + if self.lora_patcher is not None: + for name, module in self.dit.named_modules(): + if isinstance(module, AutoWrappedLinear): + merger_name = name.replace(".", "___") + if merger_name in self.lora_patcher.model_dict: + module.lora_merger = self.lora_patcher.model_dict[merger_name] + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"), + tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"), + nexus_gen_processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"), + step1x_processor_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern=""), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = FluxImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder_1 = model_pool.fetch_model("flux_text_encoder_clip") + pipe.text_encoder_2 = model_pool.fetch_model("flux_text_encoder_t5") + pipe.dit = model_pool.fetch_model("flux_dit") + pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") + pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + if tokenizer_1_config is not None: + tokenizer_1_config.download_if_necessary() + pipe.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_config.path) + if tokenizer_2_config is not None: + tokenizer_2_config.download_if_necessary() + pipe.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_config.path) + + value_controllers = model_pool.fetch_model("flux_value_controller") + if value_controllers is not None: + pipe.value_controller = MultiValueEncoder(value_controllers) + if hasattr(pipe.value_controller.encoders[0], "vram_management_enabled"): + pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled + controlnets = model_pool.fetch_model("flux_controlnet") + if controlnets is not None: pipe.controlnet = 
MultiControlNet(controlnets) + pipe.ipadapter = model_pool.fetch_model("flux_ipadapter") + pipe.ipadapter_image_encoder = model_pool.fetch_model("siglip_vision_model") + qwenvl = model_pool.fetch_model("qwen_image_text_encoder") + if qwenvl is not None: + from transformers import AutoProcessor + step1x_processor_config.download_if_necessary() + processor = AutoProcessor.from_pretrained(step1x_processor_config.path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28) + pipe.qwenvl = Step1xEditEmbedder(qwenvl, processor) + pipe.step1x_connector = model_pool.fetch_model("step1x_connector") + pipe.image_proj_model = model_pool.fetch_model("infiniteyou_image_projector") + if pipe.image_proj_model is not None: + pipe.infinityou_processor = InfinitYou(device=device) + pipe.lora_patcher = model_pool.fetch_model("flux_lora_patcher") + pipe.lora_encoder = model_pool.fetch_model("flux_lora_encoder") + pipe.nexus_gen = model_pool.fetch_model("nexus_gen_llm") + pipe.nexus_gen_generation_adapter = model_pool.fetch_model("nexus_gen_generation_adapter") + pipe.nexus_gen_editing_adapter = model_pool.fetch_model("nexus_gen_editing_adapter") + if pipe.nexus_gen is not None: + nexus_gen_processor_config.download_if_necessary() + pipe.nexus_gen.load_processor(nexus_gen_processor_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + embedded_guidance: float = 3.5, + t5_sequence_length: int = 512, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Scheduler + sigma_shift: float = None, + # Steps + num_inference_steps: int = 30, + # local prompts + multidiffusion_prompts=(), + multidiffusion_masks=(), + multidiffusion_scales=(), + # Kontext + kontext_images: Union[list[Image.Image], Image.Image] = None, + # ControlNet + controlnet_inputs: list[ControlNetInput] = None, + # IP-Adapter + ipadapter_images: Union[list[Image.Image], Image.Image] = None, + ipadapter_scale: float = 1.0, + # EliGen + eligen_entity_prompts: list[str] = None, + eligen_entity_masks: list[Image.Image] = None, + eligen_enable_on_negative: bool = False, + eligen_enable_inpaint: bool = False, + # InfiniteYou + infinityou_id_image: Image.Image = None, + infinityou_guidance: float = 1.0, + # Flex + flex_inpaint_image: Image.Image = None, + flex_inpaint_mask: Image.Image = None, + flex_control_image: Image.Image = None, + flex_control_strength: float = 0.5, + flex_control_stop: float = 0.5, + # Value Controller + value_controller_inputs: Union[list[float], float] = None, + # Step1x + step1x_reference_image: Image.Image = None, + # NexusGen + nexus_gen_reference_image: Image.Image = None, + # LoRA Encoder + lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None, + lora_encoder_scale: float = 1.0, + # TeaCache + tea_cache_l1_thresh: float = None, + # Tile + tiled: bool = False, + tile_size: int = 128, + tile_stride: int = 64, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, 
"t5_sequence_length": t5_sequence_length, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps, + "multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales, + "kontext_images": kontext_images, + "controlnet_inputs": controlnet_inputs, + "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale, + "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint, + "infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance, + "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop, + "value_controller_inputs": value_controller_inputs, + "step1x_reference_image": step1x_reference_image, + "nexus_gen_reference_image": nexus_gen_reference_image, + "lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale, + "tea_cache_l1_thresh": tea_cache_l1_thresh, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "progress_bar_cmd": progress_bar_cmd, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae_decoder']) + image = self.vae_decoder(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class FluxImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width"), output_params=("height", "width")) + + def process(self, pipe: FluxImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class FluxImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "seed", "rand_device"), output_params=("noise",)) + + def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device) + return {"noise": noise} + + + +class FluxImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, 
input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae_encoder']) + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": None} + + + +class FluxImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + input_params=("t5_sequence_length",), + output_params=("prompt_emb", "pooled_prompt_emb", "text_ids"), + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ).input_ids.to(device) + pooled_prompt_emb, _ = text_encoder(input_ids) + return pooled_prompt_emb + + def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True, + ).input_ids.to(device) + prompt_emb = text_encoder(input_ids) + return prompt_emb + + def encode_prompt( + self, + tokenizer_1, + tokenizer_2, + text_encoder_1, + text_encoder_2, + prompt, + positive=True, + device=get_device_type(), + t5_sequence_length=512, + ): + pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) + prompt_emb = self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device) + text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype) + return prompt_emb, pooled_prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, prompt, t5_sequence_length, positive) -> dict: + if pipe.text_encoder_1 is not None and pipe.text_encoder_2 is not None: + prompt_emb, pooled_prompt_emb, text_ids = self.encode_prompt( + tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2, + text_encoder_1=pipe.text_encoder_1, text_encoder_2=pipe.text_encoder_2, + prompt=prompt, device=pipe.device, positive=positive, t5_sequence_length=t5_sequence_length, + ) + return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} + else: + return {} + + +class FluxImageUnit_ImageIDs(PipelineUnit): + def __init__(self): + super().__init__(input_params=("latents",), output_params=("image_ids",)) + + def process(self, pipe: FluxImagePipeline, latents): + latent_image_ids = pipe.dit.prepare_image_ids(latents) + return {"image_ids": latent_image_ids} + + + +class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit): + def __init__(self): + super().__init__(input_params=("embedded_guidance", "latents"), output_params=("guidance",)) + + def process(self, pipe: FluxImagePipeline, embedded_guidance, latents): + guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + return {"guidance": guidance} + + + +class 
FluxImageUnit_Kontext(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("kontext_images", "tiled", "tile_size", "tile_stride"), + output_params=("kontext_latents", "kontext_image_ids"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride): + if kontext_images is None: + return {} + if not isinstance(kontext_images, list): + kontext_images = [kontext_images] + + kontext_latents = [] + kontext_image_ids = [] + for kontext_image in kontext_images: + kontext_image = pipe.preprocess_image(kontext_image) + kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image_ids = pipe.dit.prepare_image_ids(kontext_latent) + image_ids[..., 0] = 1 + kontext_image_ids.append(image_ids) + kontext_latent = pipe.dit.patchify(kontext_latent) + kontext_latents.append(kontext_latent) + kontext_latents = torch.concat(kontext_latents, dim=1) + kontext_image_ids = torch.concat(kontext_image_ids, dim=-2) + return {"kontext_latents": kontext_latents, "kontext_image_ids": kontext_image_ids} + + + +class FluxImageUnit_ControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"), + output_params=("controlnet_conditionings",), + onload_model_names=("vae_encoder",) + ) + + def apply_controlnet_mask_on_latents(self, pipe, latents, mask): + mask = (pipe.preprocess_image(mask) + 1) / 2 + mask = mask.mean(dim=1, keepdim=True) + mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:]) + latents = torch.concat([latents, mask], dim=1) + return latents + + def apply_controlnet_mask_on_image(self, pipe, image, mask): + mask = mask.resize(image.size) + mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu() + image = np.array(image) + image[mask > 0] = 0 + image = Image.fromarray(image) + return image + + def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride): + if controlnet_inputs is None: + return {} + pipe.load_models_to_device(['vae_encoder']) + conditionings = [] + for controlnet_input in controlnet_inputs: + image = controlnet_input.image + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask) + + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask) + conditionings.append(image) + return {"controlnet_conditionings": conditionings} + + + +class FluxImageUnit_IPAdapter(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("ipadapter_images", "ipadapter_scale"), + output_params=("ipadapter_kwargs_list",), + onload_model_names=("ipadapter_image_encoder", "ipadapter") + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0) + if ipadapter_images is None: + return inputs_shared, inputs_posi, inputs_nega + if not isinstance(ipadapter_images, list): + ipadapter_images = [ipadapter_images] + + pipe.load_models_to_device(self.onload_model_names) + images = 
[image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images] + images = [pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) for image in images] + ipadapter_images = torch.cat(images, dim=0) + ipadapter_image_encoding = pipe.ipadapter_image_encoder(ipadapter_images).pooler_output + + inputs_posi.update({"ipadapter_kwargs_list": pipe.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update({"ipadapter_kwargs_list": pipe.ipadapter(torch.zeros_like(ipadapter_image_encoding))}) + return inputs_shared, inputs_posi, inputs_nega + + + +class FluxImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("eligen_entity_prompts", "eligen_entity_masks", "eligen_enable_on_negative", "width", "height", "t5_sequence_length", "cfg_scale"), + output_params=("entity_prompt_emb", "entity_masks"), + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ).input_ids.to(device) + pooled_prompt_emb, _ = text_encoder(input_ids) + return pooled_prompt_emb + + def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True, + ).input_ids.to(device) + prompt_emb = text_encoder(input_ids) + return prompt_emb + + def encode_prompt( + self, + tokenizer_1, + tokenizer_2, + text_encoder_1, + text_encoder_2, + prompt, + positive=True, + device=get_device_type(), + t5_sequence_length=512, + ): + pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device) + prompt_emb = self.encode_prompt_using_t5(prompt, text_encoder_2, tokenizer_2, t5_sequence_length, device) + text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype) + return prompt_emb, pooled_prompt_emb, text_ids + + def preprocess_masks(self, pipe, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height, t5_sequence_length=512): + entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + + prompt_emb, _, _ = self.encode_prompt( + tokenizer_1=pipe.tokenizer_1, tokenizer_2=pipe.tokenizer_2, + text_encoder_1=pipe.text_encoder_1, text_encoder_2=pipe.text_encoder_2, + prompt=entity_prompts, device=pipe.device, t5_sequence_length=t5_sequence_length, + ) + return prompt_emb.unsqueeze(0), entity_masks + + def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_on_negative, cfg_scale): + entity_prompt_emb_posi, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = 
prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1) + entity_masks_nega = entity_masks_posi + else: + entity_prompt_emb_nega, entity_masks_nega = None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega} + return eligen_kwargs_posi, eligen_kwargs_nega + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None) + if eligen_entity_prompts is None or eligen_entity_masks is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) + eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, + eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], + inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"]) + inputs_posi.update(eligen_kwargs_posi) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update(eligen_kwargs_nega) + return inputs_shared, inputs_posi, inputs_nega + + +class FluxImageUnit_NexusGen(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("nexus_gen_reference_image", "prompt", "latents"), + output_params=("prompt_emb", "text_ids"), + onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"), + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if pipe.nexus_gen is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + if inputs_shared.get("nexus_gen_reference_image", None) is None: + assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set." + embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0) + inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed) + inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype) + else: + assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set." 
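+ # Editing path: encode the prompt together with the reference image, fuse both embeddings through the editing adapter, and build text ids that distinguish target tokens from reference tokens.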
+ embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"]) + embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long) + ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long) + + inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid) + inputs_posi["text_ids"] = self.get_editing_text_ids( + inputs_shared["latents"], + embeds_grid[0][1].item(), embeds_grid[0][2].item(), + ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(), + ) + return inputs_shared, inputs_posi, inputs_nega + + + def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width): + # prepare text ids for target and reference embeddings + batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width + embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype) + + batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width + ref_embed_ids = torch.zeros(height // 2, width // 2, 3) + scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width + ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0 + ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height + ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width + ref_embed_ids = ref_embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3) + ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype) + + text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1) + return text_ids + + +class FluxImageUnit_Step1x(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("step1x_reference_image", "prompt", "negative_prompt"), + output_params=("step1x_llm_embedding", "step1x_mask", "step1x_reference_latents"), + onload_model_names=("qwenvl","vae_encoder") + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict): + image = inputs_shared.get("step1x_reference_image",None) + if image is None: + return inputs_shared, inputs_posi, inputs_nega + else: + pipe.load_models_to_device(self.onload_model_names) + prompt = inputs_posi["prompt"] + nega_prompt = inputs_nega["negative_prompt"] + captions = [prompt, nega_prompt] + ref_images = [image, image] + embs, masks = pipe.qwenvl(captions, ref_images) + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae_encoder(image) + inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}) + if inputs_shared.get("cfg_scale", 1) != 1: + inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}) + return inputs_shared, inputs_posi, inputs_nega + + +class FluxImageUnit_TeaCache(PipelineUnit): + def 
__init__(self): + super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"), output_params=("tea_cache",)) + + def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh): + if tea_cache_l1_thresh is None: + return {} + else: + return {"tea_cache": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)} + +class FluxImageUnit_Flex(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"), + output_params=("flex_condition", "flex_uncondition", "flex_control_stop_timestep"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride): + if pipe.dit.input_dim == 196: + if flex_control_stop is None: + flex_control_stop = 1 + pipe.load_models_to_device(self.onload_model_names) + if flex_inpaint_image is None: + flex_inpaint_image = torch.zeros_like(latents) + else: + flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if flex_inpaint_mask is None: + flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] + else: + flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2])) + flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2 + flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask) + if flex_control_image is None: + flex_control_image = torch.zeros_like(latents) + else: + flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength + flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) + flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1) + flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))] + return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep} + else: + return {} + + + +class FluxImageUnit_InfiniteYou(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("infinityou_id_image", "infinityou_guidance"), + output_params=("id_emb", "infinityou_guidance"), + onload_model_names=("infinityou_processor",) + ) + + def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance): + pipe.load_models_to_device("infinityou_processor") + if infinityou_id_image is not None: + return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device) + else: + return {} + + + +class FluxImageUnit_ValueControl(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, + input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, + 
input_params=("value_controller_inputs",), + output_params=("prompt_emb", "text_ids"), + onload_model_names=("value_controller",) + ) + + def add_to_text_embedding(self, prompt_emb, text_ids, value_emb): + prompt_emb = torch.concat([prompt_emb, value_emb], dim=1) + extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + return prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs): + if value_controller_inputs is None: + return {} + if not isinstance(value_controller_inputs, list): + value_controller_inputs = [value_controller_inputs] + value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device) + pipe.load_models_to_device(["value_controller"]) + value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype) + value_emb = value_emb.unsqueeze(0) + prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb) + return {"prompt_emb": prompt_emb, "text_ids": text_ids} + + + +class InfinitYou(torch.nn.Module): + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__() + from facexlib.recognition import init_recognition_model + from insightface.app import FaceAnalysis + self.device = device + self.torch_dtype = torch_dtype + insightface_root_path = 'models/ByteDance/InfiniteYou/supports/insightface' + self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + self.app_640.prepare(ctx_id=0, det_size=(640, 640)) + self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + self.app_320.prepare(ctx_id=0, det_size=(320, 320)) + self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + self.app_160.prepare(ctx_id=0, det_size=(160, 160)) + self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype) + + def _detect_face(self, id_image_cv2): + face_info = self.app_640.get(id_image_cv2) + if len(face_info) > 0: + return face_info + face_info = self.app_320.get(id_image_cv2) + if len(face_info) > 0: + return face_info + face_info = self.app_160.get(id_image_cv2) + return face_info + + def extract_arcface_bgr_embedding(self, in_image, landmark, device): + from insightface.utils import face_align + arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112) + arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255. 
+ arc_face_image = 2 * arc_face_image - 1 + arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype) + face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized + return face_emb + + def prepare_infinite_you(self, model, id_image, infinityou_guidance, device): + import cv2 + if id_image is None: + return {'id_emb': None} + id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR) + face_info = self._detect_face(id_image_cv2) + if len(face_info) == 0: + raise ValueError('No face detected in the input ID image') + landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face + id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device) + id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype)) + infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype) + return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance} + + + +class FluxImageUnit_LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("lora_encoder_inputs", "lora_encoder_scale"), + output_params=("prompt_emb", "text_ids"), + onload_model_names=("lora_encoder",) + ) + + def parse_lora_encoder_inputs(self, lora_encoder_inputs): + if not isinstance(lora_encoder_inputs, list): + lora_encoder_inputs = [lora_encoder_inputs] + lora_configs = [] + for lora_encoder_input in lora_encoder_inputs: + if isinstance(lora_encoder_input, str): + lora_encoder_input = ModelConfig(path=lora_encoder_input) + lora_encoder_input.download_if_necessary() + lora_configs.append(lora_encoder_input) + return lora_configs + + def load_lora(self, lora_config, dtype, device): + loader = FluxLoRALoader(torch_dtype=dtype, device=device) + lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device) + lora = loader.convert_state_dict(lora) + return lora + + def lora_embedding(self, pipe, lora_encoder_inputs): + lora_emb = [] + for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs): + lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device) + lora_emb.append(pipe.lora_encoder(lora)) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb): + prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1) + extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype) + text_ids = torch.concat([text_ids, extra_text_ids], dim=1) + return prompt_emb, text_ids + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("lora_encoder_inputs", None) is None: + return inputs_shared, inputs_posi, inputs_nega + + # Encode + pipe.load_models_to_device(["lora_encoder"]) + lora_encoder_inputs = inputs_shared["lora_encoder_inputs"] + lora_emb = self.lora_embedding(pipe, lora_encoder_inputs) + + # Scale + lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None) + if lora_encoder_scale is not None: + lora_emb = lora_emb * lora_encoder_scale + + # Add to prompt embedding + inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding( + inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb) + return inputs_shared, inputs_posi, inputs_nega + + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh): + self.num_inference_steps = num_inference_steps + self.step 
= 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + def check(self, dit: FluxDiT, hidden_states, conditioning): + inp = hidden_states.clone() + temb_ = conditioning.clone() + modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_) + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = hidden_states.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + +class FastTileWorker: + def __init__(self): + pass + + + def build_mask(self, data, is_bound): + _, _, H, W = data.shape + h = repeat(torch.arange(H), "H -> H W", H=H, W=W) + w = repeat(torch.arange(W), "W -> H W", H=H, W=W) + border_width = (H + W) // 4 + pad = torch.ones_like(h) * border_width + mask = torch.stack([ + pad if is_bound[0] else h + 1, + pad if is_bound[1] else H - h, + pad if is_bound[2] else w + 1, + pad if is_bound[3] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=data.dtype, device=data.device) + mask = rearrange(mask, "H W -> 1 H W") + return mask + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + B, C, H, W = model_input.shape + border_width = int(tile_stride*0.5) if border_width is None else border_width + weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device) + values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device) + + # Split tasks + tasks = [] + for h in range(0, H, tile_stride): + for w in range(0, W, tile_stride): + if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W): + continue + h_, w_ = h + tile_size, w + tile_size + if h_ > H: h, h_ = H - tile_size, H + if w_ > W: w, w_ = W - tile_size, W + tasks.append((h, h_, w, w_)) + + # Run + for hl, hr, wl, wr in tasks: + # Forward + hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device) + + mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W)) + values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask + weight[:, :, hl:hr, wl:wr] += mask + values /= weight + return values + + +def model_fn_flux_image( + dit: FluxDiT, + controlnet=None, + step1x_connector=None, + latents=None, + timestep=None, + prompt_emb=None, + pooled_prompt_emb=None, + guidance=None, + text_ids=None, + image_ids=None, + kontext_latents=None, + kontext_image_ids=None, + 
controlnet_inputs=None, + controlnet_conditionings=None, + tiled=False, + tile_size=128, + tile_stride=64, + entity_prompt_emb=None, + entity_masks=None, + ipadapter_kwargs_list={}, + id_emb=None, + infinityou_guidance=None, + flex_condition=None, + flex_uncondition=None, + flex_control_stop_timestep=None, + step1x_llm_embedding=None, + step1x_mask=None, + step1x_reference_latents=None, + tea_cache: TeaCache = None, + progress_id=0, + num_inference_steps=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs +): + if tiled: + def flux_forward_fn(hl, hr, wl, wr): + tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None + return model_fn_flux_image( + dit=dit, + controlnet=controlnet, + latents=latents[:, :, hl: hr, wl: wr], + timestep=timestep, + prompt_emb=prompt_emb, + pooled_prompt_emb=pooled_prompt_emb, + guidance=guidance, + text_ids=text_ids, + image_ids=None, + controlnet_inputs=controlnet_inputs, + controlnet_conditionings=tiled_controlnet_conditionings, + tiled=False, + **kwargs + ) + return FastTileWorker().tiled_forward( + flux_forward_fn, + latents, + tile_size=tile_size, + tile_stride=tile_stride, + tile_device=latents.device, + tile_dtype=latents.dtype + ) + + hidden_states = latents + + # ControlNet + if controlnet is not None and controlnet_conditionings is not None: + controlnet_extra_kwargs = { + "hidden_states": hidden_states, + "timestep": timestep, + "prompt_emb": prompt_emb, + "pooled_prompt_emb": pooled_prompt_emb, + "guidance": guidance, + "text_ids": text_ids, + "image_ids": image_ids, + "controlnet_inputs": controlnet_inputs, + "tiled": tiled, + "tile_size": tile_size, + "tile_stride": tile_stride, + "progress_id": progress_id, + "num_inference_steps": num_inference_steps, + } + if id_emb is not None: + controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype) + controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance}) + controlnet_res_stack, controlnet_single_res_stack = controlnet( + controlnet_conditionings, **controlnet_extra_kwargs + ) + + # Flex + if flex_condition is not None: + if timestep.tolist()[0] >= flex_control_stop_timestep: + hidden_states = torch.concat([hidden_states, flex_condition], dim=1) + else: + hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1) + + # Step1x + if step1x_llm_embedding is not None: + prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask) + text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device) + + if image_ids is None: + image_ids = dit.prepare_image_ids(hidden_states) + + conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb) + if dit.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) + + height, width = hidden_states.shape[-2:] + hidden_states = dit.patchify(hidden_states) + + # Kontext + if kontext_latents is not None: + image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, kontext_latents], dim=1) + + # Step1x + if step1x_reference_latents is not None: + step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents) + 
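+ # Patchify the Step1x reference latents and append them, together with their image ids, to the main token sequence; these extra tokens are sliced off again after the final projection.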
step1x_reference_latents = dit.patchify(step1x_reference_latents) + image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2) + hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1) + + hidden_states = dit.x_embedder(hidden_states) + + # EliGen + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1]) + else: + prompt_emb = dit.context_embedder(prompt_emb) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) + else: + tea_cache_update = False + + if tea_cache_update: + hidden_states = tea_cache.update(hidden_states) + else: + # Joint Blocks + for block_id, block in enumerate(dit.blocks): + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None), + ) + # ControlNet + if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None: + if kontext_latents is None: + hidden_states = hidden_states + controlnet_res_stack[block_id] + else: + hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id] + + # Single Blocks + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + num_joint_blocks = len(dit.blocks) + for block_id, block in enumerate(dit.single_blocks): + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), + ) + # ControlNet + if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None: + if kontext_latents is None: + hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] + else: + hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id] + hidden_states = hidden_states[:, prompt_emb.shape[1]:] + + if tea_cache is not None: + tea_cache.store(hidden_states) + + hidden_states = dit.final_norm_out(hidden_states, conditioning) + hidden_states = dit.final_proj_out(hidden_states) + + # Step1x + if step1x_reference_latents is not None: + hidden_states = hidden_states[:, :hidden_states.shape[1] // 2] + + # Kontext + if kontext_latents is not None: + hidden_states = hidden_states[:, :-kontext_latents.shape[1]] + + hidden_states = dit.unpatchify(hidden_states, height, width) + + return hidden_states diff --git a/DiffSynth-Studio/diffsynth/pipelines/qwen_image.py b/DiffSynth-Studio/diffsynth/pipelines/qwen_image.py new file mode 100644 index 0000000000000000000000000000000000000000..75cfbee77a9ae23297fbd16c7752351193c00f78 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/pipelines/qwen_image.py @@ -0,0 +1,815 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops 
import rearrange +import numpy as np +from math import prod + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora.merge import merge_lora + +from ..models.qwen_image_dit import QwenImageDiT +from ..models.qwen_image_text_encoder import QwenImageTextEncoder +from ..models.qwen_image_vae import QwenImageVAE +from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet +from ..models.siglip2_image_encoder import Siglip2ImageEncoder +from ..models.dinov3_image_encoder import DINOv3ImageEncoder +from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel + + +class QwenImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + from transformers import Qwen2Tokenizer, Qwen2VLProcessor + + self.scheduler = FlowMatchScheduler("Qwen-Image") + self.text_encoder: QwenImageTextEncoder = None + self.dit: QwenImageDiT = None + self.vae: QwenImageVAE = None + self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None + self.tokenizer: Qwen2Tokenizer = None + self.siglip2_image_encoder: Siglip2ImageEncoder = None + self.dinov3_image_encoder: DINOv3ImageEncoder = None + self.image2lora_style: QwenImageImage2LoRAModel = None + self.image2lora_coarse: QwenImageImage2LoRAModel = None + self.image2lora_fine: QwenImageImage2LoRAModel = None + self.processor: Qwen2VLProcessor = None + self.in_iteration_models = ("dit", "blockwise_controlnet") + self.units = [ + QwenImageUnit_ShapeChecker(), + QwenImageUnit_NoiseInitializer(), + QwenImageUnit_InputImageEmbedder(), + QwenImageUnit_Inpaint(), + QwenImageUnit_EditImageEmbedder(), + QwenImageUnit_LayerInputImageEmbedder(), + QwenImageUnit_ContextImageEmbedder(), + QwenImageUnit_PromptEmbedder(), + QwenImageUnit_EntityControl(), + QwenImageUnit_BlockwiseControlNet(), + ] + self.model_fn = model_fn_qwen_image + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + processor_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("qwen_image_text_encoder") + pipe.dit = model_pool.fetch_model("qwen_image_dit") + pipe.vae = model_pool.fetch_model("qwen_image_vae") + pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_pool.fetch_model("qwen_image_blockwise_controlnet", index="all")) + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + from transformers import Qwen2Tokenizer + pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path) + if processor_config is not None: + processor_config.download_if_necessary() + from transformers import Qwen2VLProcessor + pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path) + pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder") + pipe.dinov3_image_encoder = 
model_pool.fetch_model("dinov3_image_encoder") + pipe.image2lora_style = model_pool.fetch_model("qwen_image_image2lora_style") + pipe.image2lora_coarse = model_pool.fetch_model("qwen_image_image2lora_coarse") + pipe.image2lora_fine = model_pool.fetch_model("qwen_image_image2lora_fine") + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 4.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Inpaint + inpaint_mask: Image.Image = None, + inpaint_blur_size: int = None, + inpaint_blur_sigma: float = None, + # Shape + height: int = 1328, + width: int = 1328, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + exponential_shift_mu: float = None, + # Blockwise ControlNet + blockwise_controlnet_inputs: list[ControlNetInput] = None, + # EliGen + eligen_entity_prompts: list[str] = None, + eligen_entity_masks: list[Image.Image] = None, + eligen_enable_on_negative: bool = False, + # Qwen-Image-Edit + edit_image: Image.Image = None, + edit_image_auto_resize: bool = True, + edit_rope_interpolation: bool = False, + # Qwen-Image-Edit-2511 + zero_cond_t: bool = False, + # Qwen-Image-Layered + layer_input_image: Image.Image = None, + layer_num: int = None, + # In-context control + context_image: Image.Image = None, + # Tile + tiled: bool = False, + tile_size: int = 128, + tile_stride: int = 64, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + "blockwise_controlnet_inputs": blockwise_controlnet_inputs, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, + "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation, + "context_image": context_image, + "zero_cond_t": zero_cond_t, + "layer_input_image": layer_input_image, + "layer_num": layer_num, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, 
noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if layer_num is None: + image = self.vae_output_to_image(image) + else: + image = [self.vae_output_to_image(i, pattern="C H W") for i in image] + self.load_models_to_device([]) + + return image + + +class QwenImageBlockwiseMultiControlNet(torch.nn.Module): + def __init__(self, models: list[QwenImageBlockWiseControlNet]): + super().__init__() + if not isinstance(models, list): + models = [models] + self.models = torch.nn.ModuleList(models) + for model in models: + if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"): + self.vram_management_enabled = True + + def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs): + processed_conditionings = [] + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning) + processed_conditionings.append(model_output) + return processed_conditionings + + def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs): + res = 0 + for controlnet_input, conditioning in zip(controlnet_inputs, conditionings): + progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1) + if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4): + continue + model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id) + res = res + model_output * controlnet_input.scale + return res + + +class QwenImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: QwenImagePipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + + +class QwenImageUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device", "layer_num"), + output_params=("noise",), + ) + + def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num): + if layer_num is None: + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + else: + noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + + +class QwenImageUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + if isinstance(input_image, list): + input_latents = [] + for image in input_image: + image = pipe.preprocess_image(image).to(device=pipe.device, 
dtype=pipe.torch_dtype) + input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)) + input_latents = torch.concat(input_latents, dim=0) + else: + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"), + output_params=("layer_input_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride): + if layer_input_image is None: + return {} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return {"layer_input_latents": latents} + + +class QwenImageUnit_Inpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"), + output_params=("inpaint_mask",), + ) + + def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma): + if inpaint_mask is None: + return {} + inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1) + inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True) + if inpaint_blur_size is not None and inpaint_blur_sigma is not None: + from torchvision.transforms import GaussianBlur + blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma) + inpaint_mask = blur(inpaint_mask) + return {"inpaint_mask": inpaint_mask} + + +class QwenImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + input_params=("edit_image",), + output_params=("prompt_emb", "prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def calculate_dimensions(self, target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + width = round(width / 32) * 32 + height = round(height / 32) * 32 + return width, height + + def resize_image(self, image, target_area=384*384): + width, height = self.calculate_dimensions(target_area, image.size[0] / image.size[1]) + return image.resize((width, height)) + + def encode_prompt(self, pipe: QwenImagePipeline, prompt): + template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 34 + txt = [template.format(e) for e in prompt] + model_inputs = 
pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + if model_inputs.input_ids.shape[1] >= 1024: + print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.") + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + return split_hidden_states + + def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image): + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + return split_hidden_states + + def encode_prompt_edit_multi(self, pipe: QwenImagePipeline, prompt, edit_image): + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + base_img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(edit_image))]) + txt = [template.format(base_img_prompt + e) for e in prompt] + edit_image = [self.resize_image(image) for image in edit_image] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + return split_hidden_states + + def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict: + pipe.load_models_to_device(self.onload_model_names) + if pipe.text_encoder is not None: + prompt = [prompt] + if edit_image is None: + split_hidden_states = self.encode_prompt(pipe, prompt) + elif isinstance(edit_image, Image.Image): + split_hidden_states = self.encode_prompt_edit(pipe, prompt, edit_image) + else: + split_hidden_states = self.encode_prompt_edit_multi(pipe, prompt, edit_image) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask} + else: + return {} + + +class QwenImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"), + output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"), + onload_model_names=("text_encoder",) + ) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict: + if pipe.text_encoder is not None: + prompt = [prompt] + template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 34 + txt = [template.format(e) for e in prompt] + txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1] + + split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + 
split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask} + else: + return {} + + def preprocess_masks(self, pipe, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height): + entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + prompt_embs, prompt_emb_masks = [], [] + for entity_prompt in entity_prompts: + prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt) + prompt_embs.append(prompt_emb_dict['prompt_emb']) + prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask']) + return prompt_embs, prompt_emb_masks, entity_masks + + def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale): + entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi) + entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi) + entity_masks_nega = entity_masks_posi + else: + entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask} + return eligen_kwargs_posi, eligen_kwargs_nega + + def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega): + eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None) + if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) + eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, + eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], + eligen_enable_on_negative, inputs_shared["cfg_scale"]) + inputs_posi.update(eligen_kwargs_posi) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update(eligen_kwargs_nega) + 
return inputs_shared, inputs_posi, inputs_nega + + + +class QwenImageUnit_BlockwiseControlNet(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"), + output_params=("blockwise_controlnet_conditioning",), + onload_model_names=("vae",) + ) + + def apply_controlnet_mask_on_latents(self, pipe, latents, mask): + mask = (pipe.preprocess_image(mask) + 1) / 2 + mask = mask.mean(dim=1, keepdim=True) + mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:]) + latents = torch.concat([latents, mask], dim=1) + return latents + + def apply_controlnet_mask_on_image(self, pipe, image, mask): + mask = mask.resize(image.size) + mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu() + image = np.array(image) + image[mask > 0] = 0 + image = Image.fromarray(image) + return image + + def process(self, pipe: QwenImagePipeline, blockwise_controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride): + if blockwise_controlnet_inputs is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + conditionings = [] + for controlnet_input in blockwise_controlnet_inputs: + image = controlnet_input.image + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask) + + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + if controlnet_input.inpaint_mask is not None: + image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask) + conditionings.append(image) + + return {"blockwise_controlnet_conditioning": conditionings} + + +class QwenImageUnit_EditImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"), + output_params=("edit_latents", "edit_image"), + onload_model_names=("vae",) + ) + + + def calculate_dimensions(self, target_area, ratio): + import math + width = math.sqrt(target_area * ratio) + height = width / ratio + width = round(width / 32) * 32 + height = round(height / 32) * 32 + return width, height + + + def edit_image_auto_resize(self, edit_image): + calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1]) + return edit_image.resize((calculated_width, calculated_height)) + + + def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False): + if edit_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if isinstance(edit_image, Image.Image): + resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image + edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype) + edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + else: + resized_edit_image, edit_latents = [], [] + for image in edit_image: + if edit_image_auto_resize: + image = self.edit_image_auto_resize(image) + resized_edit_image.append(image) + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + edit_latents.append(latents) + return {"edit_latents": edit_latents, "edit_image": 
resized_edit_image} + + +class QwenImageUnit_Image2LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_images",), + output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder", "text_encoder"), + ) + from ..core.data.operators import ImageCropAndResize + self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8) + self.processor_highres = ImageCropAndResize(height=1024, width=1024) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image): + prompt = [prompt] + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return prompt_embeds.view(1, -1) + + def encode_images_using_siglip2(self, pipe: QwenImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["siglip2_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_dinov3(self, pipe: QwenImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["dinov3_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_qwenvl(self, pipe: QwenImagePipeline, images: list[Image.Image], highres=False): + pipe.load_models_to_device(["text_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) if highres else self.processor_lowres(image) + embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image)) + embs = torch.stack(embs) + return embs + + def encode_images(self, pipe: QwenImagePipeline, images: list[Image.Image]): + if images is None: + return {} + if not isinstance(images, list): + images = [images] + embs_siglip2 = self.encode_images_using_siglip2(pipe, images) + embs_dinov3 = 
self.encode_images_using_dinov3(pipe, images) + x = torch.concat([embs_siglip2, embs_dinov3], dim=-1) + residual = None + residual_highres = None + if pipe.image2lora_coarse is not None: + residual = self.encode_images_using_qwenvl(pipe, images, highres=False) + if pipe.image2lora_fine is not None: + residual_highres = self.encode_images_using_qwenvl(pipe, images, highres=True) + return x, residual, residual_highres + + def process(self, pipe: QwenImagePipeline, image2lora_images): + if image2lora_images is None: + return {} + x, residual, residual_highres = self.encode_images(pipe, image2lora_images) + return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres} + + +class QwenImageUnit_Image2LoRADecode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + output_params=("lora",), + onload_model_names=("image2lora_coarse", "image2lora_fine", "image2lora_style"), + ) + + def process(self, pipe: QwenImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres): + if image2lora_x is None: + return {} + loras = [] + if pipe.image2lora_style is not None: + pipe.load_models_to_device(["image2lora_style"]) + for x in image2lora_x: + loras.append(pipe.image2lora_style(x=x, residual=None)) + if pipe.image2lora_coarse is not None: + pipe.load_models_to_device(["image2lora_coarse"]) + for x, residual in zip(image2lora_x, image2lora_residual): + loras.append(pipe.image2lora_coarse(x=x, residual=residual)) + if pipe.image2lora_fine is not None: + pipe.load_models_to_device(["image2lora_fine"]) + for x, residual in zip(image2lora_x, image2lora_residual_highres): + loras.append(pipe.image2lora_fine(x=x, residual=residual)) + lora = merge_lora(loras, alpha=1 / len(image2lora_x)) + return {"lora": lora} + + +class QwenImageUnit_ContextImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("context_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride): + if context_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype) + context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return {"context_latents": context_latents} + + +def model_fn_qwen_image( + dit: QwenImageDiT = None, + blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None, + latents=None, + timestep=None, + prompt_emb=None, + prompt_emb_mask=None, + height=None, + width=None, + blockwise_controlnet_conditioning=None, + blockwise_controlnet_inputs=None, + progress_id=0, + num_inference_steps=1, + entity_prompt_emb=None, + entity_prompt_emb_mask=None, + entity_masks=None, + edit_latents=None, + layer_input_latents=None, + layer_num=None, + context_latents=None, + enable_fp8_attention=False, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + edit_rope_interpolation=False, + zero_cond_t=False, + **kwargs +): + if layer_num is None: + layer_num = 1 + img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] + else: + layer_num = layer_num + 1 + img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num + txt_seq_lens = 
prompt_emb_mask.sum(dim=1).tolist() + timestep = timestep / 1000 + + image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=layer_num) + image_seq_len = image.shape[1] + + if context_latents is not None: + img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)] + context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2) + image = torch.cat([image, context_image], dim=1) + if edit_latents is not None: + edit_latents_list = edit_latents if isinstance(edit_latents, list) else [edit_latents] + img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list] + edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list] + image = torch.cat([image] + edit_image, dim=1) + if layer_input_latents is not None: + layer_num = layer_num + 1 + img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)] + layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + image = torch.cat([image, layer_input_latents], dim=1) + + image = dit.img_in(image) + if zero_cond_t: + timestep = torch.cat([timestep, timestep * 0], dim=0) + modulate_index = torch.tensor( + [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [img_shapes]], + device=timestep.device, + dtype=torch.int, + ) + else: + modulate_index = None + conditioning = dit.time_text_embed( + timestep, + image.dtype, + addition_t_cond=None if not dit.time_text_embed.use_additional_t_cond else torch.tensor([0]).to(device=image.device, dtype=torch.long) + ) + + if entity_prompt_emb is not None: + text, image_rotary_emb, attention_mask = dit.process_entity_masks( + latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, + entity_masks, height, width, image, img_shapes, + ) + else: + text = dit.txt_in(dit.txt_norm(prompt_emb)) + if edit_rope_interpolation: + image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device) + else: + image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + attention_mask = None + + if blockwise_controlnet_conditioning is not None: + blockwise_controlnet_conditioning = blockwise_controlnet.preprocess( + blockwise_controlnet_inputs, blockwise_controlnet_conditioning) + + for block_id, block in enumerate(dit.transformer_blocks): + text, image = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + enable_fp8_attention=enable_fp8_attention, + modulate_index=modulate_index, + ) + if blockwise_controlnet_conditioning is not None: + image_slice = image[:, :image_seq_len].clone() + controlnet_output = blockwise_controlnet.blockwise_forward( + image=image_slice, conditionings=blockwise_controlnet_conditioning, + controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id, + progress_id=progress_id, num_inference_steps=num_inference_steps, + ) + image[:, :image_seq_len] = image_slice + controlnet_output + + if zero_cond_t: + conditioning = conditioning.chunk(2, dim=0)[0] + image = dit.norm_out(image, conditioning) + image = dit.proj_out(image) + image = image[:, 
:image_seq_len] + + latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1) + return latents diff --git a/DiffSynth-Studio/diffsynth/pipelines/wan_video.py b/DiffSynth-Studio/diffsynth/pipelines/wan_video.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb23e34a6fdf5cbe13669ea40efa70f243e6c96 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/pipelines/wan_video.py @@ -0,0 +1,1599 @@ +import torch, types +import numpy as np +from PIL import Image +from einops import repeat +from typing import Optional, Union +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from typing_extensions import Literal +from transformers import Wav2Vec2Processor +import json +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit +import safetensors.torch +from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d +from ..models.wan_video_dit_s2v import rope_precompute +from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer +from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_image_encoder import WanImageEncoder +from ..models.wan_video_vace import VaceWanModel +from ..models.wan_video_motion_controller import WanMotionControllerModel +from ..models.wan_video_animate_adapter import WanAnimateAdapter +from ..models.wan_video_mot import MotWanModel +from ..models.wav2vec import WanS2VAudioEncoder +from ..models.longcat_video_dit import LongCatVideoTransformer3DModel + + + +def load_file(path): + state_dict = safetensors.torch.load_file(path, device="cpu") + return dict(state_dict) + +class WanVideoPipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler("Wan") + self.tokenizer: HuggingfaceTokenizer = None + self.audio_processor: Wav2Vec2Processor = None + self.text_encoder: WanTextEncoder = None + self.image_encoder: WanImageEncoder = None + self.dit: WanModel = None + self.dit2: WanModel = None + self.vae: WanVideoVAE = None + self.motion_controller: WanMotionControllerModel = None + self.vace: VaceWanModel = None + self.vace2: VaceWanModel = None + self.vap: MotWanModel = None + self.animate_adapter: WanAnimateAdapter = None + self.audio_encoder: WanS2VAudioEncoder = None + self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap") + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_AnimateVideoSplit(), + WanVideoUnit_AnimatePoseLatents(), + WanVideoUnit_AnimateFacePixelValues(), + WanVideoUnit_AnimateInpaint(), + WanVideoUnit_VAP(), + 
WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger(), + WanVideoUnit_LongCatVideo(), + ] + self.post_units = [ + WanVideoPostUnit_S2V(), + ] + self.model_fn = model_fn_wan_video + + + def enable_usp(self): + from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward + + for block in self.dit.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) + self.sp_size = get_sequence_parallel_world_size() + self.use_unified_sequence_parallel = True + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + audio_processor_config: ModelConfig = None, + redirect_common_files: bool = True, + use_usp: bool = False, + vram_limit: float = None, + wan_paths: list[str] = [], + wan_config_path: str = None, + ): + # Redirect model path + if redirect_common_files: + redirect_dict = { + "models_t5_umt5-xxl-enc-bf16.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_t5_umt5-xxl-enc-bf16.safetensors"), + "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.safetensors"), + "Wan2.1_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.1_VAE.safetensors"), + "Wan2.2_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.2_VAE.safetensors"), + } + for model_config in model_configs: + if model_config.origin_file_pattern is None or model_config.model_id is None: + continue + if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern][0]: + print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. 
You can use `redirect_common_files=False` to disable file redirection.") + model_config.model_id = redirect_dict[model_config.origin_file_pattern][0] + model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1] + + if use_usp: + from ..utils.xfuser import initialize_usp + initialize_usp(device) + import torch.distributed as dist + from ..core.device.npu_compatible_device import get_device_name + if dist.is_available() and dist.is_initialized(): + device = get_device_name() + # Initialize pipeline + + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder") + + print(f"====== go load wan config ======") + with open(wan_config_path, "r") as f: + config = json.load(f) + dit = WanModel(**config) + + print(f"====== go load wan weight ======") + dit_state_dict = {} + for each in wan_paths: + dit_state_dict.update(load_file(each)) + missing, unexpected = dit.load_state_dict(dit_state_dict, strict=False) + with torch.no_grad(): + miss = set(missing) + for name, p in dit.named_parameters(): + if name in miss: + p.zero_() + for name, b in dit.named_buffers(): + if name in miss: + if b.is_floating_point() or b.is_complex(): + b.zero_() + else: + b.fill_(0) + print(f"====== load wan weight ok ======") + pipe.dit = dit.to(torch.bfloat16) + + + # dit = model_pool.fetch_model("wan_video_dit", index=2) + # if isinstance(dit, list): + # pipe.dit, pipe.dit2 = dit + # else: + # pipe.dit = dit + + pipe.vae = model_pool.fetch_model("wan_video_vae") + pipe.image_encoder = model_pool.fetch_model("wan_video_image_encoder") + pipe.motion_controller = model_pool.fetch_model("wan_video_motion_controller") + vace = model_pool.fetch_model("wan_video_vace", index=2) + if isinstance(vace, list): + pipe.vace, pipe.vace2 = vace + else: + pipe.vace = vace + pipe.vap = model_pool.fetch_model("wan_video_vap") + pipe.audio_encoder = model_pool.fetch_model("wans2v_audio_encoder") + pipe.animate_adapter = model_pool.fetch_model("wan_video_animate_adapter") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 + + # Initialize tokenizer and processor + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace') + if audio_processor_config is not None: + audio_processor_config.download_if_necessary() + pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) + + # Unified Sequence Parallel + if use_usp: pipe.enable_usp() + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + + ########## + src_video: Optional[list[Image.Image]] = None, + tgt_video: Optional[list[Image.Image]] = None, + ########## + + # Video-to-video + input_video: Optional[list[Image.Image]] = None, + denoising_strength: Optional[float] = 1.0, + # Speech-to-video + input_audio: Optional[np.array] = None, + audio_embeds: Optional[torch.Tensor] = None, + audio_sample_rate: Optional[int] = 16000, + 
s2v_pose_video: Optional[list[Image.Image]] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + motion_video: Optional[list[Image.Image]] = None, + # ControlNet + control_video: Optional[list[Image.Image]] = None, + reference_image: Optional[Image.Image] = None, + # Camera control + camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, + camera_control_speed: Optional[float] = 1/54, + camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), + # VACE + vace_video: Optional[list[Image.Image]] = None, + vace_video_mask: Optional[Image.Image] = None, + vace_reference_image: Optional[Image.Image] = None, + vace_scale: Optional[float] = 1.0, + # Animate + animate_pose_video: Optional[list[Image.Image]] = None, + animate_face_video: Optional[list[Image.Image]] = None, + animate_inpaint_video: Optional[list[Image.Image]] = None, + animate_mask_video: Optional[list[Image.Image]] = None, + # VAP + vap_video: Optional[list[Image.Image]] = None, + vap_prompt: Optional[str] = " ", + negative_vap_prompt: Optional[str] = " ", + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 480, + width: Optional[int] = 832, + num_frames=81, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + cfg_merge: Optional[bool] = False, + # Boundary + switch_DiT_boundary: Optional[float] = 0.875, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # Speed control + motion_bucket_id: Optional[int] = None, + # LongCat-Video + longcat_video: Optional[list[Image.Image]] = None, + # VAE tiling + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # Sliding window + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + # Teacache + tea_cache_l1_thresh: Optional[float] = None, + tea_cache_model_id: Optional[str] = "", + # progress_bar + progress_bar_cmd=tqdm, + output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized", + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": prompt, + "vap_prompt": vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_nega = { + "negative_prompt": negative_prompt, + "negative_vap_prompt": negative_vap_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "src_video": src_video, "tgt_video":tgt_video, + "input_video": input_video, "denoising_strength": denoising_strength, + "control_video": control_video, "reference_image": reference_image, + "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, + "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "sigma_shift": sigma_shift, + 
"motion_bucket_id": motion_bucket_id, + "longcat_video": longcat_video, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, + "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video, + "vap_video": vap_video, + } + for unit in self.units: + + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + tgt_latent_length = inputs_shared["latents"].shape[2] + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + models["vace"] = self.vace2 + + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + inputs_shared["latents"] = torch.cat([inputs_shared["latents"], inputs_shared["input_latents"]], dim=2) + + # Inference + noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred[:,:,:tgt_latent_length,...], self.scheduler.timesteps[progress_id], inputs_shared["latents"][:,:,:tgt_latent_length,...]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] + + # VACE (TODO: remove it) + if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None): + if vace_reference_image is not None and isinstance(vace_reference_image, list): + f = len(vace_reference_image) + else: + f = 1 + inputs_shared["latents"] = inputs_shared["latents"][:, :, f:] + # post-denoising, pre-decoding processing logic + for unit in self.post_units: + inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Decode + self.load_models_to_device(['vae']) + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if output_type == "quantized": + video = self.vae_output_to_video(video) + elif output_type == "floatpoint": + pass + self.load_models_to_device([]) + return video + + + +class WanVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": 
width, "num_frames": num_frames} + + + +class WanVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"), + output_params=("noise",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + length = (num_frames - 1) // 4 + 1 + if vace_reference_image is not None: + f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 + length += f + shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) + if vace_reference_image is not None: + noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2) + return {"noise": noise} + + + +class WanVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("src_video", "tgt_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, src_video, tgt_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + if src_video is None: + return {"latents": noise} + pipe.load_models_to_device(self.onload_model_names) + + src_video = pipe.preprocess_video(src_video) + src_latents = pipe.vae.encode(src_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + if tgt_video is not None: + tgt_video = pipe.preprocess_video(tgt_video) + tgt_latents = pipe.vae.encode(tgt_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + # print() + # print(src_latents.shape) + # print(tgt_latents.shape) + # print() + + input_latents = torch.concat([tgt_latents, src_latents], dim=2) + + + if vace_reference_image is not None: + if not isinstance(vace_reference_image, list): + vace_reference_image = [vace_reference_image] + vace_reference_image = pipe.preprocess_video(vace_reference_image) + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + return {"latents": noise, "input_latents": src_latents} + + + +class WanVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + output_params=("context",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: + pipe.load_models_to_device(self.onload_model_names) + prompt_emb = self.encode_prompt(pipe, prompt) + return {"context": prompt_emb} + + + +class 
WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "height", "width"), + output_params=("clip_feature",), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: + + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + + + +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.require_vae_embedding: + + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. 
+ """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "fuse_vae_embedding_in_latents", "first_frame_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + + + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + + + +class WanVideoUnit_FunControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"), + output_params=("clip_feature", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents): + if control_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + control_video = pipe.preprocess_video(control_video) + control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) + y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1] + if clip_feature is None or y is None: + clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) + y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) + else: + y = y[:, -y_dim:] + y = torch.concat([control_latents, y], dim=1) + return {"clip_feature": clip_feature, "y": y} + + + +class WanVideoUnit_FunReference(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("reference_image", "height", "width", "reference_image"), + output_params=("reference_latents", "clip_feature"), + onload_model_names=("vae", "image_encoder") + ) + + def process(self, pipe: WanVideoPipeline, reference_image, height, width): + if reference_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + reference_image = reference_image.resize((width, height)) + reference_latents = pipe.preprocess_video([reference_image]) + reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + if pipe.image_encoder is None: + return {"reference_latents": reference_latents} + clip_feature = pipe.preprocess_image(reference_image) + clip_feature = pipe.image_encoder.encode_image([clip_feature]) + return {"reference_latents": reference_latents, "clip_feature": clip_feature} + + + +class WanVideoUnit_FunCameraControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("control_camera_latents_input", "y"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, 
camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): + if camera_control_direction is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( + camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) + + control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) + + input_image = input_image.resize((width, height)) + input_latents = pipe.preprocess_video([input_image]) + input_latents = pipe.vae.encode(input_latents, device=pipe.device) + y = torch.zeros_like(latents).to(pipe.device) + y[:, :, :1] = input_latents + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + if y.shape[1] != pipe.dit.in_dim - latents.shape[1]: + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + y = torch.cat([msk,y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"control_camera_latents_input": control_camera_latents_input, "y": y} + + + +class WanVideoUnit_SpeedControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("motion_bucket_id",), + output_params=("motion_bucket_id",) + ) + + def process(self, pipe: WanVideoPipeline, motion_bucket_id): + if motion_bucket_id is None: + return {} + motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"motion_bucket_id": motion_bucket_id} + + + +class WanVideoUnit_VACE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), + output_params=("vace_context", "vace_scale"), + onload_model_names=("vae",) + ) + + def process( + self, + pipe: WanVideoPipeline, + vace_video, vace_video_mask, vace_reference_image, vace_scale, + height, width, num_frames, + tiled, tile_size, tile_stride + ): + if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None: + pipe.load_models_to_device(["vae"]) + if vace_video is None: + vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) + else: + vace_video = 
pipe.preprocess_video(vace_video) + + if vace_video_mask is None: + vace_video_mask = torch.ones_like(vace_video) + else: + vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) + + inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask + reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) + inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + if not isinstance(vace_reference_image,list): + vace_reference_image = [vace_reference_image] + + vace_reference_image = pipe.preprocess_video(vace_reference_image) + + bs, c, f, h, w = vace_reference_image.shape + new_vace_ref_images = [] + for j in range(f): + new_vace_ref_images.append(vace_reference_image[0, :, j:j+1]) + vace_reference_image = new_vace_ref_images + + vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents] + + vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + return {"vace_context": vace_context, "vace_scale": vace_scale} + else: + return {"vace_context": None, "vace_scale": vace_scale} + + +class WanVideoUnit_VAP(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("text_encoder", "vae", "image_encoder"), + input_params=("vap_video", "vap_prompt", "negative_vap_prompt", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("vap_clip_feature", "vap_hidden_state", "context_vap") + ) + + def encode_prompt(self, pipe: WanVideoPipeline, prompt): + ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("vap_video") is None: + return inputs_shared, inputs_posi, inputs_nega + else: + # 1. 
encode vap prompt + pipe.load_models_to_device(["text_encoder"]) + vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "") + vap_prompt_emb = self.encode_prompt(pipe, vap_prompt) + negative_vap_prompt_emb = self.encode_prompt(pipe, negative_vap_prompt) + inputs_posi.update({"context_vap":vap_prompt_emb}) + inputs_nega.update({"context_vap":negative_vap_prompt_emb}) + # 2. prepare vap image clip embedding + pipe.load_models_to_device(["vae", "image_encoder"]) + vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image") + + num_frames, height, width = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width") + + image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device) + + vap_clip_context = pipe.image_encoder.encode_image([image_vap]) + if end_image is not None: + vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1) + vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_clip_feature":vap_clip_context}) + + # 3. prepare vap latents + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + msk[:, -1:] = 1 + last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device) + vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1) + else: + vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + tiled,tile_size,tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_video = pipe.preprocess_video(vap_video) + vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + + vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device) + inputs_shared.update({"vap_hidden_state":vap_latent}) + + return inputs_shared, inputs_posi, inputs_nega + + + +class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",)) + + def process(self, pipe: WanVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel"): + if pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {} + + + +class WanVideoUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": 
"tea_cache_model_id"}, + input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + output_params=("tea_cache",) + ) + + def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): + if tea_cache_l1_thresh is None: + return {} + return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} + + + +class WanVideoUnit_CfgMerger(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if not inputs_shared["cfg_merge"]: + return inputs_shared, inputs_posi, inputs_nega + for name in self.concat_tensor_names: + tensor_posi = inputs_posi.get(name) + tensor_nega = inputs_nega.get(name) + tensor_shared = inputs_shared.get(name) + if tensor_posi is not None and tensor_nega is not None: + inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) + elif tensor_shared is not None: + inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) + inputs_posi.clear() + inputs_nega.clear() + return inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("audio_encoder", "vae",), + input_params=("input_audio", "audio_embeds", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "audio_sample_rate", "s2v_pose_video", "s2v_pose_latents", "motion_video"), + output_params=("audio_embeds", "motion_latents", "drop_motion_frames", "s2v_pose_latents"), + ) + + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): + if audio_embeds is not None: + return {"audio_embeds": audio_embeds} + pipe.load_models_to_device(["audio_encoder"]) + audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device) + if return_all: + return audio_embeds + else: + return {"audio_embeds": audio_embeds[0]} + + def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None): + pipe.load_models_to_device(["vae"]) + motion_frames = 73 + kwargs = {} + if motion_video is not None: + assert motion_video.shape[2] == motion_frames, f"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}" + motion_latents = motion_video + kwargs["drop_motion_frames"] = False + else: + motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) + kwargs["drop_motion_frames"] = True + motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + kwargs.update({"motion_latents": motion_latents}) + return kwargs + + def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False): + if s2v_pose_latents is not None: + return {"s2v_pose_latents": s2v_pose_latents} + if s2v_pose_video is None: + return {"s2v_pose_latents": None} + pipe.load_models_to_device(["vae"]) + 
infer_frames = num_frames - 1 + input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats] + # pad if not enough frames + padding_frames = infer_frames * num_repeats - input_video.shape[2] + input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) + input_videos = input_video.chunk(num_repeats, dim=2) + pose_conds = [] + for r in range(num_repeats): + cond = input_videos[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) + cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + pose_conds.append(cond_latents[:,:,1:]) + if return_all: + return pose_conds + else: + return {"s2v_pose_latents": pose_conds[0]} + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None: + return inputs_shared, inputs_posi, inputs_nega + num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000) + s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None) + + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds) + inputs_posi.update(audio_input_positive) + inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]}) + + inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video)) + inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents)) + return inputs_shared, inputs_posi, inputs_nega + + @staticmethod + def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)): + assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first." 
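+        # Shape checking mirrors WanVideoUnit_ShapeChecker, then the S2V helpers below pre-compute
+        # one audio-embedding batch per generation pass (plus matching pose latents when a pose
+        # video is provided); the returned count is the number of passes over the input audio.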
+        shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames)
+        height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"]
+        unit = WanVideoUnit_S2V()
+        audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True)
+        pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        pose_latents = None if s2v_pose_video is None else pose_latents
+        return audio_embeds, pose_latents, len(audio_embeds)
+
+
+class WanVideoPostUnit_S2V(PipelineUnit):
+    def __init__(self):
+        super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames"))
+
+    def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames):
+        if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames:
+            return {}
+        latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2)
+        return {"latents": latents}
+
+
+class WanVideoUnit_AnimateVideoSplit(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"),
+            output_params=("animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video")
+        )
+
+    def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video):
+        if input_video is None:
+            return {}
+        if animate_pose_video is not None:
+            animate_pose_video = animate_pose_video[:len(input_video) - 4]
+        if animate_face_video is not None:
+            animate_face_video = animate_face_video[:len(input_video) - 4]
+        if animate_inpaint_video is not None:
+            animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4]
+        if animate_mask_video is not None:
+            animate_mask_video = animate_mask_video[:len(input_video) - 4]
+        return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video}
+
+
+class WanVideoUnit_AnimatePoseLatents(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"),
+            output_params=("pose_latents",),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride):
+        if animate_pose_video is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        animate_pose_video = pipe.preprocess_video(animate_pose_video)
+        pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"pose_latents": pose_latents}
+
+
+class WanVideoUnit_AnimateFacePixelValues(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            take_over=True,
+            input_params=("animate_face_video",),
+            output_params=("face_pixel_values",),
+        )
+
+    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
+        if inputs_shared.get("animate_face_video", None) is None:
+            return inputs_shared, inputs_posi, inputs_nega
+        inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"])
+        inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1
+        return inputs_shared, inputs_posi, inputs_nega
+
+
+class 
WanVideoUnit_AnimateInpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("vae",) + ) + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + return msk + + def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride): + if animate_inpaint_video is None or animate_mask_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + + bg_pixel_values = pipe.preprocess_video(animate_inpaint_video) + y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device) + _, lat_t, lat_h, lat_w = y_reft.shape + + ref_pixel_values = pipe.preprocess_video([input_image]) + ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device) + y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device) + + mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0) + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0] + msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device) + + y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device) + y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0) + return {"y": y} + + +class WanVideoUnit_LongCatVideo(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("longcat_video",), + output_params=("longcat_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, longcat_video): + if longcat_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + longcat_video = pipe.preprocess_video(longcat_video) + longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"longcat_latents": longcat_latents} + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 
1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +class TemporalTiler_BCTHW: + def __init__(self): + pass + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if border_width == 0: + return x + + shift = 0.5 + if not left_bound: + x[:border_width] = (torch.arange(border_width) + shift) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, T, _, _ = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + mask = repeat(t, "T -> 1 1 T 1 1") + return mask + + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): + tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] + tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} + B, C, T, H, W = tensor_dict[tensor_names[0]].shape + if batch_size is not None: + B *= batch_size + data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype + value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) + weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) + for t in range(0, T, sliding_window_stride): + if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: + continue + t_ = min(t + sliding_window_size, T) + model_kwargs.update({ + tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ + for tensor_name in tensor_names + }) + model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) + mask = self.build_mask( + model_output, + is_bound=(t == 0, t_ == T), + border_width=(sliding_window_size - sliding_window_stride,) + ).to(device=data_device, dtype=data_dtype) + value[:, :, t: t_, 
:, :] += model_output * mask + weight[:, :, t: t_, :, :] += mask + value /= weight + model_kwargs.update(tensor_dict) + return value + + + +def model_fn_wan_video( + dit: WanModel, + motion_controller: WanMotionControllerModel = None, + vace: VaceWanModel = None, + vap: MotWanModel = None, + animate_adapter: WanAnimateAdapter = None, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + reference_latents = None, + vace_context = None, + vace_scale = 1.0, + audio_embeds: Optional[torch.Tensor] = None, + motion_latents: Optional[torch.Tensor] = None, + s2v_pose_latents: Optional[torch.Tensor] = None, + vap_hidden_state = None, + vap_clip_feature = None, + context_vap = None, + drop_motion_frames: bool = True, + tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, + motion_bucket_id: Optional[torch.Tensor] = None, + pose_latents=None, + face_pixel_values=None, + longcat_latents=None, + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + cfg_merge: bool = False, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, + **kwargs, +): + if sliding_window_size is not None and sliding_window_stride is not None: + model_kwargs = dict( + dit=dit, + motion_controller=motion_controller, + vace=vace, + latents=latents, + timestep=timestep, + context=context, + clip_feature=clip_feature, + y=y, + reference_latents=reference_latents, + vace_context=vace_context, + vace_scale=vace_scale, + tea_cache=tea_cache, + use_unified_sequence_parallel=use_unified_sequence_parallel, + motion_bucket_id=motion_bucket_id, + ) + return TemporalTiler_BCTHW().run( + model_fn_wan_video, + sliding_window_size, sliding_window_stride, + latents.device, latents.dtype, + model_kwargs=model_kwargs, + tensor_names=["latents", "y"], + batch_size=2 if cfg_merge else 1 + ) + # LongCat-Video + if isinstance(dit, LongCatVideoTransformer3DModel): + return model_fn_longcat_video( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + longcat_latents=longcat_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + # wan2.2 s2v + if audio_embeds is not None: + return model_fn_wans2v( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + audio_embeds=audio_embeds, + motion_latents=motion_latents, + s2v_pose_latents=s2v_pose_latents, + drop_motion_frames=drop_motion_frames, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + use_gradient_checkpointing=use_gradient_checkpointing, + use_unified_sequence_parallel=use_unified_sequence_parallel, + ) + + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + if 
use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) + t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] + t = t_chunks[get_sequence_parallel_rank()] + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + # Motion Controller + if motion_bucket_id is not None and motion_controller is not None: + t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) + context = dit.text_embedding(context) + + + x = latents + # Merged cfg + if x.shape[0] != context.shape[0]: + x = torch.concat([x] * context.shape[0], dim=0) + if timestep.shape[0] != context.shape[0]: + timestep = torch.concat([timestep] * context.shape[0], dim=0) + + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: + clip_embdding = dit.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + # Camera control + x = dit.patchify(x, control_camera_latents_input) + + + # Animate + if pose_latents is not None and face_pixel_values is not None: + x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values) + + # Patchify + f, h, w = x.shape[2:] + first_frame_len = h * w + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + + + # Reference image + if reference_latents is not None: + if len(reference_latents.shape) == 5: + reference_latents = reference_latents[:, :, 0] + reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) + x = torch.concat([reference_latents, x], dim=1) + f += 1 + + + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # VAP + if vap is not None: + # hidden state + x_vap = vap_hidden_state + x_vap = vap.patchify(x_vap) + x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous() + # Timestep + clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype) + t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep)) + t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim)) + + # rope + freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device) + + # context + vap_clip_embedding = vap.img_emb(vap_clip_feature) + context_vap = vap.text_embedding(context_vap) + context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1) + + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + + if vace_context is not None: + vace_hints = vace( + x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload + ) + + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = 
chunks[get_sequence_parallel_rank()] + if tea_cache_update: + x = tea_cache.update(x) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + def create_custom_forward_vap(block, vap): + def custom_forward(*inputs): + return vap(block, *inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + # Block + if vap is not None and block_id in vap.mot_layers_mapping: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x, x_vap = torch.utils.checkpoint.checkpoint( + create_custom_forward_vap(block, vap), + x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, + use_reentrant=False, + ) + else: + x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) + else: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + + + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, first_frame_len, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs, first_frame_len) + + # VACE + if vace_context is not None and block_id in vace.vace_layers_mapping: + current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) + x = x + current_vace_hint * vace_scale + + # Animate + if pose_latents is not None and face_pixel_values is not None: + x = animate_adapter.after_transformer_block(block_id, x, motion_vec) + if tea_cache is not None: + tea_cache.store(x) + + x = dit.head(x, t) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + # Remove reference latents + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 + x = dit.unpatchify(x, (f, h, w)) + return x + + +def model_fn_longcat_video( + dit: LongCatVideoTransformer3DModel, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + longcat_latents: torch.Tensor = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, +): + if longcat_latents is not None: + latents[:, :, :longcat_latents.shape[2]] = longcat_latents + num_cond_latents = longcat_latents.shape[2] + else: + num_cond_latents = 0 + context = context.unsqueeze(0) + encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64) + output = dit( + latents, + timestep, + context, + encoder_attention_mask, + num_cond_latents=num_cond_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + output = -output + output = output.to(latents.dtype) + return 
output + + +def model_fn_wans2v( + dit, + latents, + timestep, + context, + audio_embeds, + motion_latents, + s2v_pose_latents, + drop_motion_frames=True, + use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False, + use_unified_sequence_parallel=False, +): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + origin_ref_latents = latents[:, :, 0:1] + x = latents[:, :, 1:] + + # context embedding + context = dit.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds) + + # x and s2v_pose_latents + s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) + seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel + + # reference image + ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) + grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) + x = torch.cat([x, ref_latents], dim=1) + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None) + # motion + x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2) + + x = x + dit.trainable_cond_mask(mask).to(x.dtype) + + # tmod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() + assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}" + x = torch.chunk(x, world_size, dim=1)[sp_rank] + seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy()) + seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] + seq_len_x = seq_len_x_list[sp_rank] + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, 
seq_len_x)), + x, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + + x = x[:, :seq_len_x_global] + x = dit.head(x, t[:-1]) + x = dit.unpatchify(x, (f, h, w)) + # make compatible with wan video + x = torch.cat([origin_ref_latents, x], dim=2) + return x diff --git a/DiffSynth-Studio/diffsynth/pipelines/z_image.py b/DiffSynth-Studio/diffsynth/pipelines/z_image.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5b68730f633a7e429e104c77464ca5680607ca --- /dev/null +++ b/DiffSynth-Studio/diffsynth/pipelines/z_image.py @@ -0,0 +1,669 @@ +import torch, math +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange +import numpy as np +from typing import Union, List, Optional, Tuple, Iterable, Dict + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..core.data.operators import ImageCropAndResize +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora import merge_lora + +from transformers import AutoTokenizer +from ..models.z_image_text_encoder import ZImageTextEncoder +from ..models.z_image_dit import ZImageDiT +from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M +from ..models.z_image_controlnet import ZImageControlNet +from ..models.siglip2_image_encoder import Siglip2ImageEncoder +from ..models.dinov3_image_encoder import DINOv3ImageEncoder +from ..models.z_image_image2lora import ZImageImage2LoRAModel + + +class ZImagePipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("Z-Image") + self.text_encoder: ZImageTextEncoder = None + self.dit: ZImageDiT = None + self.vae_encoder: FluxVAEEncoder = None + self.vae_decoder: FluxVAEDecoder = None + self.image_encoder: Siglip2ImageEncoder428M = None + self.controlnet: ZImageControlNet = None + self.siglip2_image_encoder: Siglip2ImageEncoder = None + self.dinov3_image_encoder: DINOv3ImageEncoder = None + self.image2lora_style: ZImageImage2LoRAModel = None + self.tokenizer: AutoTokenizer = None + self.in_iteration_models = ("dit", "controlnet") + self.units = [ + ZImageUnit_ShapeChecker(), + ZImageUnit_PromptEmbedder(), + ZImageUnit_NoiseInitializer(), + ZImageUnit_InputImageEmbedder(), + ZImageUnit_EditImageAutoResize(), + ZImageUnit_EditImageEmbedderVAE(), + ZImageUnit_EditImageEmbedderSiglip(), + ZImageUnit_PAIControlNet(), + ] + self.model_fn = model_fn_z_image + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = ZImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = 
pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("z_image_dit") + pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") + pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m") + pipe.controlnet = model_pool.fetch_model("z_image_controlnet") + pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder") + pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder") + pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Edit + edit_image: Image.Image = None, + edit_image_auto_resize: bool = True, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 8, + sigma_shift: float = None, + # ControlNet + controlnet_inputs: List[ControlNetInput] = None, + # Image to LoRA + image2lora_images: List[Image.Image] = None, + positive_only_lora: Dict[str, torch.Tensor] = None, + # Progress bar + progress_bar_cmd = tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Parameters + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, + "controlnet_inputs": controlnet_inputs, + "image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae_decoder']) + image = self.vae_decoder(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class ZImageUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: ZImagePipeline, height, width): + height, width = 
pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + +class ZImageUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params=("edit_image",), + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt( + self, + pipe, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = pipe.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def encode_prompt_omni( + self, + pipe, + prompt: Union[str, List[str]], + edit_image=None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + if isinstance(prompt, str): + prompt = [prompt] + + if edit_image is None: + num_condition_images = 0 + elif isinstance(edit_image, list): + num_condition_images = len(edit_image) + else: + num_condition_images = 1 + + for i, prompt_item in enumerate(prompt): + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = pipe.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def process(self, pipe: ZImagePipeline, prompt, edit_image): + pipe.load_models_to_device(self.onload_model_names) + if hasattr(pipe, "dit") and pipe.dit.siglip_embedder is not None: + # Z-Image-Turbo 
and Z-Image-Omni-Base use different prompt encoding methods.
+            # We determine which encoding method to use based on the model architecture.
+            # If you are using two-stage split training,
+            # please use `--offload_models` instead of skipping the DiT model loading.
+            prompt_embeds = self.encode_prompt_omni(pipe, prompt, edit_image, pipe.device)
+        else:
+            prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
+        return {"prompt_embeds": prompt_embeds}
+
+
+class ZImageUnit_NoiseInitializer(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width", "seed", "rand_device"),
+            output_params=("noise",),
+        )
+
+    def process(self, pipe: ZImagePipeline, height, width, seed, rand_device):
+        noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+        return {"noise": noise}
+
+
+class ZImageUnit_InputImageEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("input_image", "noise"),
+            output_params=("latents", "input_latents"),
+            onload_model_names=("vae_encoder",)
+        )
+
+    def process(self, pipe: ZImagePipeline, input_image, noise):
+        if input_image is None:
+            return {"latents": noise, "input_latents": None}
+        pipe.load_models_to_device(['vae_encoder'])
+        image = pipe.preprocess_image(input_image)
+        input_latents = pipe.vae_encoder(image)
+        if pipe.scheduler.training:
+            return {"latents": noise, "input_latents": input_latents}
+        else:
+            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+            return {"latents": latents, "input_latents": input_latents}
+
+
+class ZImageUnit_EditImageAutoResize(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("edit_image", "edit_image_auto_resize"),
+            output_params=("edit_image",),
+        )
+
+    def process(self, pipe: ZImagePipeline, edit_image, edit_image_auto_resize):
+        if edit_image is None:
+            return {}
+        if edit_image_auto_resize is None or not edit_image_auto_resize:
+            return {}
+        operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
+        if not isinstance(edit_image, list):
+            edit_image = [edit_image]
+        edit_image = [operator(i) for i in edit_image]
+        return {"edit_image": edit_image}
+
+
+class ZImageUnit_EditImageEmbedderSiglip(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("edit_image",),
+            output_params=("image_embeds",),
+            onload_model_names=("image_encoder",)
+        )
+
+    def process(self, pipe: ZImagePipeline, edit_image):
+        if edit_image is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        if not isinstance(edit_image, list):
+            edit_image = [edit_image]
+        image_emb = []
+        for image_ in edit_image:
+            image_emb.append(pipe.image_encoder(image_, device=pipe.device))
+        return {"image_embeds": image_emb}
+
+
+class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("edit_image",),
+            output_params=("image_latents",),
+            onload_model_names=("vae_encoder",)
+        )
+
+    def process(self, pipe: ZImagePipeline, edit_image):
+        if edit_image is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        if not isinstance(edit_image, list):
+            edit_image = [edit_image]
+        image_latents = []
+        for image_ in edit_image:
+            image_ = pipe.preprocess_image(image_)
+            image_latents.append(pipe.vae_encoder(image_))
+        return {"image_latents": image_latents}
+
+
+class ZImageUnit_PAIControlNet(PipelineUnit):
+    def __init__(self):
+        
super().__init__( + input_params=("controlnet_inputs", "height", "width"), + output_params=("control_context", "control_scale"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width): + if controlnet_inputs is None: + return {} + if len(controlnet_inputs) != 1: + print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.") + controlnet_input = controlnet_inputs[0] + pipe.load_models_to_device(self.onload_model_names) + + control_image = controlnet_input.image + if control_image is not None: + control_image = pipe.preprocess_image(control_image) + control_latents = pipe.vae_encoder(control_image) + else: + control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1 + + inpaint_mask = controlnet_input.inpaint_mask + if inpaint_mask is not None: + inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1) + inpaint_image = controlnet_input.inpaint_image + inpaint_image = pipe.preprocess_image(inpaint_image) + inpaint_image = inpaint_image * (inpaint_mask < 0.5) + inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1] + else: + inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) + inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device) + inpaint_latent = pipe.vae_encoder(inpaint_image) + + control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1) + control_context = rearrange(control_context, "B C H W -> B C 1 H W") + return {"control_context": control_context, "control_scale": controlnet_input.scale} + + +def model_fn_z_image( + dit: ZImageDiT, + controlnet: ZImageControlNet = None, + latents=None, + timestep=None, + prompt_embeds=None, + image_embeds=None, + image_latents=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + # Due to the complex and verbose codebase of Z-Image, + # we are temporarily using this inelegant structure. + # We will refactor this part in the future (if time permits). 
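+    # Models without a SigLIP embedder (the Turbo-style checkpoints) are routed to
+    # model_fn_z_image_turbo below; otherwise the Omni path wraps the latents (and any
+    # edit-image latents) into nested lists with an image_noise_mask before calling the DiT.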
+ if dit.siglip_embedder is None: + return model_fn_z_image_turbo( + dit, + controlnet=controlnet, + latents=latents, + timestep=timestep, + prompt_embeds=prompt_embeds, + image_embeds=image_embeds, + image_latents=image_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + **kwargs, + ) + latents = [rearrange(latents, "B C H W -> C B H W")] + if dit.siglip_embedder is not None: + if image_latents is not None: + image_latents = [rearrange(image_latent, "B C H W -> C B H W") for image_latent in image_latents] + latents = [image_latents + latents] + image_noise_mask = [[0] * len(image_latents) + [1]] + else: + latents = [latents] + image_noise_mask = [[1]] + image_embeds = [image_embeds] + else: + image_noise_mask = None + timestep = (1000 - timestep) / 1000 + model_output = dit( + latents, + timestep, + prompt_embeds, + siglip_feats=image_embeds, + image_noise_mask=image_noise_mask, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + )[0] + model_output = -model_output + model_output = rearrange(model_output, "C B H W -> B C H W") + return model_output + + +class ZImageUnit_Image2LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_images",), + output_params=("image2lora_x",), + onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",), + ) + from ..core.data.operators import ImageCropAndResize + self.processor_highres = ImageCropAndResize(height=1024, width=1024) + + def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["siglip2_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["dinov3_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]): + if images is None: + return {} + if not isinstance(images, list): + images = [images] + embs_siglip2 = self.encode_images_using_siglip2(pipe, images) + embs_dinov3 = self.encode_images_using_dinov3(pipe, images) + x = torch.concat([embs_siglip2, embs_dinov3], dim=-1) + return x + + def process(self, pipe: ZImagePipeline, image2lora_images): + if image2lora_images is None: + return {} + x = self.encode_images(pipe, image2lora_images) + return {"image2lora_x": x} + + +class ZImageUnit_Image2LoRADecode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_x",), + output_params=("lora",), + onload_model_names=("image2lora_style",), + ) + + def process(self, pipe: ZImagePipeline, image2lora_x): + if image2lora_x is None: + return {} + loras = [] + if pipe.image2lora_style is not None: + pipe.load_models_to_device(["image2lora_style"]) + for x in image2lora_x: + loras.append(pipe.image2lora_style(x=x, residual=None)) + lora = merge_lora(loras, alpha=1 / len(image2lora_x)) + return {"lora": lora} + + +def model_fn_z_image_turbo( + dit: ZImageDiT, + controlnet: ZImageControlNet = None, + latents=None, + timestep=None, + prompt_embeds=None, + image_embeds=None, 
+ image_latents=None, + control_context=None, + control_scale=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + while isinstance(prompt_embeds, list): + prompt_embeds = prompt_embeds[0] + while isinstance(latents, list): + latents = latents[0] + while isinstance(image_embeds, list): + image_embeds = image_embeds[0] + + # Timestep + timestep = 1000 - timestep + t_noisy = dit.t_embedder(timestep) + t_clean = dit.t_embedder(torch.ones_like(timestep) * 1000) + + # Patchify + latents = rearrange(latents, "B C H W -> C B H W") + x, cap_feats, patch_metadata = dit.patchify_and_embed([latents], [prompt_embeds]) + x = x[0] + cap_feats = cap_feats[0] + + # Noise refine + x = dit.all_x_embedder["2-1"](x) + x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device) + x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0)) + x = rearrange(x, "L C -> 1 L C") + x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C") + + if control_context is not None: + kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy) + refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner( + dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1, + use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + for layer_id, layer in enumerate(dit.noise_refiner): + x = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=x, + attn_mask=None, + freqs_cis=x_freqs_cis, + adaln_input=t_noisy, + ) + if control_context is not None: + x = x + refiner_hints[layer_id] * control_scale + + # Prompt refine + cap_feats = dit.cap_embedder(cap_feats) + cap_feats[torch.cat(patch_metadata.get("cap_pad_mask"))] = dit.cap_pad_token.to(dtype=x.dtype, device=x.device) + cap_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("cap_pos_ids"), dim=0)) + cap_feats = rearrange(cap_feats, "L C -> 1 L C") + cap_freqs_cis = rearrange(cap_freqs_cis, "L C -> 1 L C") + + for layer in dit.context_refiner: + cap_feats = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=cap_feats, + attn_mask=None, + freqs_cis=cap_freqs_cis, + ) + + # Unified + unified = torch.cat([x, cap_feats], dim=1) + unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1) + + if control_context is not None: + kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy) + hints = controlnet.forward_layers( + unified, cap_feats, control_context, control_context_item_seqlens, kwargs, + use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + for layer_id, layer in enumerate(dit.layers): + unified = gradient_checkpoint_forward( + layer, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=unified, + attn_mask=None, + freqs_cis=unified_freqs_cis, + adaln_input=t_noisy, + ) + if control_context is not None: + if layer_id in controlnet.control_layers_mapping: + unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale + + # Output + unified = 
dit.all_final_layer["2-1"](unified, t_noisy) + x = dit.unpatchify([unified[0]], patch_metadata.get("x_size"))[0] + x = rearrange(x, "C B H W -> B C H W") + x = -x + return x diff --git a/DiffSynth-Studio/diffsynth/utils/controlnet/__init__.py b/DiffSynth-Studio/diffsynth/utils/controlnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df23b6c61b99319f1f41d6448b0c52ffd03b9f25 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/controlnet/__init__.py @@ -0,0 +1,2 @@ +from .controlnet_input import ControlNetInput +from .annotator import Annotator diff --git a/DiffSynth-Studio/diffsynth/utils/controlnet/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/utils/controlnet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a92cd5e55757fb10b4553521ba93e4b988e0a97 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/utils/controlnet/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/utils/controlnet/__pycache__/annotator.cpython-39.pyc b/DiffSynth-Studio/diffsynth/utils/controlnet/__pycache__/annotator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cea5862d47283fd6fc1121c2a56404719292b245 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/utils/controlnet/__pycache__/annotator.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/utils/controlnet/__pycache__/controlnet_input.cpython-39.pyc b/DiffSynth-Studio/diffsynth/utils/controlnet/__pycache__/controlnet_input.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b736d7ed2c070319282fd325b85093e8eb61f9f Binary files /dev/null and b/DiffSynth-Studio/diffsynth/utils/controlnet/__pycache__/controlnet_input.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/utils/controlnet/annotator.py b/DiffSynth-Studio/diffsynth/utils/controlnet/annotator.py new file mode 100644 index 0000000000000000000000000000000000000000..cb737385f75bf1edd681c24c3118b0ac0d79e185 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/controlnet/annotator.py @@ -0,0 +1,63 @@ +from typing_extensions import Literal, TypeAlias + +from diffsynth.core.device.npu_compatible_device import get_device_type + +Processor_id: TypeAlias = Literal[ + "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint" +] + +class Annotator: + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False): + if not skip_processor: + if processor_id == "canny": + from controlnet_aux.processor import CannyDetector + self.processor = CannyDetector() + elif processor_id == "depth": + from controlnet_aux.processor import MidasDetector + self.processor = MidasDetector.from_pretrained(model_path).to(device) + elif processor_id == "softedge": + from controlnet_aux.processor import HEDdetector + self.processor = HEDdetector.from_pretrained(model_path).to(device) + elif processor_id == "lineart": + from controlnet_aux.processor import LineartDetector + self.processor = LineartDetector.from_pretrained(model_path).to(device) + elif processor_id == "lineart_anime": + from controlnet_aux.processor import LineartAnimeDetector + self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) + elif processor_id == "openpose": + from controlnet_aux.processor import OpenposeDetector + self.processor = 
OpenposeDetector.from_pretrained(model_path).to(device) + elif processor_id == "normal": + from controlnet_aux.processor import NormalBaeDetector + self.processor = NormalBaeDetector.from_pretrained(model_path).to(device) + elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint": + self.processor = None + else: + raise ValueError(f"Unsupported processor_id: {processor_id}") + else: + self.processor = None + + self.processor_id = processor_id + self.detect_resolution = detect_resolution + + def to(self,device): + if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"): + + self.processor.model.to(device) + + def __call__(self, image, mask=None): + width, height = image.size + if self.processor_id == "openpose": + kwargs = { + "include_body": True, + "include_hand": True, + "include_face": True + } + else: + kwargs = {} + if self.processor is not None: + detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height) + image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs) + image = image.resize((width, height)) + return image + diff --git a/DiffSynth-Studio/diffsynth/utils/controlnet/controlnet_input.py b/DiffSynth-Studio/diffsynth/utils/controlnet/controlnet_input.py new file mode 100644 index 0000000000000000000000000000000000000000..a79064bb51fa3625ca692a183544ec9720ca33b9 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/controlnet/controlnet_input.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from PIL import Image + + +@dataclass +class ControlNetInput: + controlnet_id: int = 0 + scale: float = 1.0 + start: float = 1.0 + end: float = 0.0 + image: Image.Image = None + inpaint_image: Image.Image = None + inpaint_mask: Image.Image = None + processor_id: str = None diff --git a/DiffSynth-Studio/diffsynth/utils/data/__init__.py b/DiffSynth-Studio/diffsynth/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b9daa41bea9d36e012d52a1d280d1cf8d92850 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/data/__init__.py @@ -0,0 +1,217 @@ +import imageio, os +import numpy as np +from PIL import Image +from tqdm import tqdm +import subprocess +import shutil + + +class LowMemoryVideo: + def __init__(self, file_name): + self.reader = imageio.get_reader(file_name) + + def __len__(self): + return self.reader.count_frames() + + def __getitem__(self, item): + return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB") + + def __del__(self): + self.reader.close() + + +def split_file_name(file_name): + result = [] + number = -1 + for i in file_name: + if ord(i)>=ord("0") and ord(i)<=ord("9"): + if number == -1: + number = 0 + number = number*10 + ord(i) - ord("0") + else: + if number != -1: + result.append(number) + number = -1 + result.append(i) + if number != -1: + result.append(number) + result = tuple(result) + return result + + +def search_for_images(folder): + file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] + file_list = [(split_file_name(file_name), file_name) for file_name in file_list] + file_list = [i[1] for i in sorted(file_list)] + file_list = [os.path.join(folder, i) for i in file_list] + return file_list + + +class LowMemoryImageFolder: + def __init__(self, folder, file_list=None): + if file_list is None: + self.file_list = search_for_images(folder) + else: + self.file_list = [os.path.join(folder, file_name) for file_name in file_list] + + 
def __len__(self): + return len(self.file_list) + + def __getitem__(self, item): + return Image.open(self.file_list[item]).convert("RGB") + + def __del__(self): + pass + + +def crop_and_resize(image, height, width): + image = np.array(image) + image_height, image_width, _ = image.shape + if image_height / image_width < height / width: + croped_width = int(image_height / height * width) + left = (image_width - croped_width) // 2 + image = image[:, left: left+croped_width] + image = Image.fromarray(image).resize((width, height)) + else: + croped_height = int(image_width / width * height) + left = (image_height - croped_height) // 2 + image = image[left: left+croped_height, :] + image = Image.fromarray(image).resize((width, height)) + return image + + +class VideoData: + def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs): + if video_file is not None: + self.data_type = "video" + self.data = LowMemoryVideo(video_file, **kwargs) + elif image_folder is not None: + self.data_type = "images" + self.data = LowMemoryImageFolder(image_folder, **kwargs) + else: + raise ValueError("Cannot open video or image folder") + self.length = None + self.set_shape(height, width) + + def raw_data(self): + frames = [] + for i in range(self.__len__()): + frames.append(self.__getitem__(i)) + return frames + + def set_length(self, length): + self.length = length + + def set_shape(self, height, width): + self.height = height + self.width = width + + def __len__(self): + if self.length is None: + return len(self.data) + else: + return self.length + + def shape(self): + if self.height is not None and self.width is not None: + return self.height, self.width + else: + height, width, _ = self.__getitem__(0).shape + return height, width + + def __getitem__(self, item): + frame = self.data.__getitem__(item) + width, height = frame.size + if self.height is not None and self.width is not None: + if self.height != height or self.width != width: + frame = crop_and_resize(frame, self.height, self.width) + return frame + + def __del__(self): + pass + + def save_images(self, folder): + os.makedirs(folder, exist_ok=True) + for i in tqdm(range(self.__len__()), desc="Saving images"): + frame = self.__getitem__(i) + frame.save(os.path.join(folder, f"{i}.png")) + + +def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): + writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + +def save_frames(frames, save_path): + os.makedirs(save_path, exist_ok=True) + for i, frame in enumerate(tqdm(frames, desc="Saving images")): + frame.save(os.path.join(save_path, f"{i}.png")) + + +def merge_video_audio(video_path: str, audio_path: str): + # TODO: may need a in-python implementation to avoid subprocess dependency + """ + Merge the video and audio into a new video, with the duration set to the shorter of the two, + and overwrite the original video file. 
+ + Parameters: + video_path (str): Path to the original video file + audio_path (str): Path to the audio file + """ + + # check + if not os.path.exists(video_path): + raise FileNotFoundError(f"video file {video_path} does not exist") + if not os.path.exists(audio_path): + raise FileNotFoundError(f"audio file {audio_path} does not exist") + + base, ext = os.path.splitext(video_path) + temp_output = f"{base}_temp{ext}" + + try: + # create ffmpeg command + command = [ + 'ffmpeg', + '-y', # overwrite + '-i', + video_path, + '-i', + audio_path, + '-c:v', + 'copy', # copy video stream + '-c:a', + 'aac', # use AAC audio encoder + '-b:a', + '192k', # set audio bitrate (optional) + '-map', + '0:v:0', # select the first video stream + '-map', + '1:a:0', # select the first audio stream + '-shortest', # choose the shortest duration + temp_output + ] + + # execute the command + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # check result + if result.returncode != 0: + error_msg = f"FFmpeg execute failed: {result.stderr}" + print(error_msg) + raise RuntimeError(error_msg) + + shutil.move(temp_output, video_path) + print(f"Merge completed, saved to {video_path}") + + except Exception as e: + if os.path.exists(temp_output): + os.remove(temp_output) + print(f"merge_video_audio failed with error: {e}") + + +def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None): + save_video(frames, save_path, fps, quality, ffmpeg_params) + merge_video_audio(save_path, audio_path) diff --git a/DiffSynth-Studio/diffsynth/utils/lora/__init__.py b/DiffSynth-Studio/diffsynth/utils/lora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb5901acba99ed8490079b8ebaeb6991ae3f59d --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/lora/__init__.py @@ -0,0 +1,3 @@ +from .general import GeneralLoRALoader +from .merge import merge_lora +from .reset_rank import reset_lora_rank \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a48f5b2e84f2e31bd555bb1a11c102cca1df436 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/general.cpython-39.pyc b/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/general.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b61867b6d0cb07d5f906c90108d926962e7aef37 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/general.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/merge.cpython-39.pyc b/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/merge.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe73ae2d5ea1549ab83a5987d32a3ce54e8b82b0 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/merge.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/reset_rank.cpython-39.pyc b/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/reset_rank.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..103817c5ba7a4de8c4696e75a85569ef435b7346 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/utils/lora/__pycache__/reset_rank.cpython-39.pyc differ diff --git 
a/DiffSynth-Studio/diffsynth/utils/lora/flux.py b/DiffSynth-Studio/diffsynth/utils/lora/flux.py new file mode 100644 index 0000000000000000000000000000000000000000..97599b652a8cb97a004f8d3264d1ac4716612260 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/lora/flux.py @@ -0,0 +1,302 @@ +from .general import GeneralLoRALoader +import torch, math + + +class FluxLoRALoader(GeneralLoRALoader): + def __init__(self, device="cpu", torch_dtype=torch.float32): + super().__init__(device=device, torch_dtype=torch_dtype) + + self.diffusers_rename_dict = { + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.weight", + 
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.weight", + } + + self.civitai_rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.weight", + 
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.weight", + } + + def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + super().fuse_lora_to_base_model(model, state_dict_lora, alpha) + + def convert_state_dict(self, state_dict): + + def guess_block_id(name,model_resource): + if model_resource == 'civitai': + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + if model_resource == 'diffusers': + names = name.split(".") + for i in names: + if i.isdigit(): + return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.") + return None, None + + def guess_resource(state_dict): + for k in state_dict: + if "lora_unet_" in k: + return 'civitai' + elif k.startswith("transformer."): + return 'diffusers' + else: + None + + model_resource = guess_resource(state_dict) + if model_resource is None: + return state_dict + + rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict + def guess_alpha(state_dict): + for name, param in state_dict.items(): + if ".alpha" in name: + for suffix in [".lora_down.weight", ".lora_A.weight"]: + name_ = name.replace(".alpha", suffix) + if name_ in state_dict: + lora_alpha = param.item() / state_dict[name_].shape[0] + lora_alpha = math.sqrt(lora_alpha) + return lora_alpha + + return 1 + + alpha = guess_alpha(state_dict) + + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name,model_resource) + if alpha != 1: + param *= alpha + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + + if model_resource == 'diffusers': + for name in list(state_dict_.keys()): + if "single_blocks." in name and ".a_to_q." 
in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + dim = 4 + if 'lora_A' in name: + dim = 1 + mlp = torch.zeros(dim * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + + mlp = mlp.to(device=state_dict_[name].device) + if 'lora_A' in name: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + elif 'lora_B' in name: + d, r = state_dict_[name].shape + param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device) + param[:d, :r] = state_dict_.pop(name) + param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")) + param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")) + param[3*d:, 3*r:] = mlp + else: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") + state_dict_[name_] = param + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + concat_dim = 0 + if 'lora_A' in name: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + elif 'lora_B' in name: + origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + d, r = origin.shape + # print(d, r) + param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device) + param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")] + param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")] + else: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + return state_dict_ + + +class FluxLoRAConverter: + def __init__(self): + pass + + @staticmethod + def align_to_opensource_format(state_dict, alpha=None): + prefix_rename_dict = { + "single_blocks": "lora_unet_single_blocks", + "blocks": "lora_unet_double_blocks", + } + middle_rename_dict = { + "norm.linear": "modulation_lin", + "to_qkv_mlp": "linear1", + "proj_out": "linear2", + + "norm1_a.linear": "img_mod_lin", + "norm1_b.linear": "txt_mod_lin", + "attn.a_to_qkv": "img_attn_qkv", + "attn.b_to_qkv": "txt_attn_qkv", + "attn.a_to_out": "img_attn_proj", + "attn.b_to_out": "txt_attn_proj", + "ff_a.0": "img_mlp_0", + "ff_a.2": "img_mlp_2", + "ff_b.0": "txt_mlp_0", + "ff_b.2": "txt_mlp_2", + } + suffix_rename_dict = { + "lora_B.weight": 
"lora_up.weight", + "lora_A.weight": "lora_down.weight", + } + state_dict_ = {} + for name, param in state_dict.items(): + names = name.split(".") + if names[-2] != "lora_A" and names[-2] != "lora_B": + names.pop(-2) + prefix = names[0] + middle = ".".join(names[2:-2]) + suffix = ".".join(names[-2:]) + block_id = names[1] + if middle not in middle_rename_dict: + continue + rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix] + state_dict_[rename] = param + if rename.endswith("lora_up.weight"): + lora_alpha = alpha if alpha is not None else param.shape[-1] + state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0] + return state_dict_ + + @staticmethod + def align_to_diffsynth_format(state_dict): + rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight", + 
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight", + } + def guess_block_id(name): + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + return None, None + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name) + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/utils/lora/general.py b/DiffSynth-Studio/diffsynth/utils/lora/general.py new file mode 100644 index 0000000000000000000000000000000000000000..624549d518fb8f2a43b04625b268cbab4441a21a --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/lora/general.py @@ -0,0 +1,62 @@ +import torch + + +class GeneralLoRALoader: + def __init__(self, device="cpu", torch_dtype=torch.float32): + self.device = device + self.torch_dtype = torch_dtype + + + def get_name_dict(self, lora_state_dict): + lora_name_dict = {} + for key in lora_state_dict: + if ".lora_up." in key: + lora_A_key = "lora_down" + lora_B_key = "lora_up" + else: + lora_A_key = "lora_A" + lora_B_key = "lora_B" + if lora_B_key not in key: + continue + keys = key.split(".") + if len(keys) > keys.index(lora_B_key) + 2: + keys.pop(keys.index(lora_B_key) + 1) + keys.pop(keys.index(lora_B_key)) + if keys[0] == "diffusion_model": + keys.pop(0) + keys.pop(-1) + target_name = ".".join(keys) + lora_name_dict[target_name] = (key, key.replace(lora_B_key, lora_A_key)) + return lora_name_dict + + + def convert_state_dict(self, state_dict, suffix=".weight"): + name_dict = self.get_name_dict(state_dict) + state_dict_ = {} + for name in name_dict: + weight_up = state_dict[name_dict[name][0]] + weight_down = state_dict[name_dict[name][1]] + state_dict_[name + f".lora_B{suffix}"] = weight_up + state_dict_[name + f".lora_A{suffix}"] = weight_down + return state_dict_ + + + def fuse_lora_to_base_model(self, model: torch.nn.Module, state_dict, alpha=1.0): + updated_num = 0 + state_dict = self.convert_state_dict(state_dict) + lora_layer_names = set([i.replace(".lora_B.weight", "") for i in state_dict if i.endswith(".lora_B.weight")]) + for name, module in model.named_modules(): + if name in lora_layer_names: + weight_up = state_dict[name + ".lora_B.weight"].to(device=self.device, dtype=self.torch_dtype) + weight_down = state_dict[name + ".lora_A.weight"].to(device=self.device, dtype=self.torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + weight_lora = alpha * torch.mm(weight_up, weight_down) + state_dict_base = module.state_dict() + state_dict_base["weight"] = state_dict_base["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora + module.load_state_dict(state_dict_base) + updated_num += 1 + print(f"{updated_num} tensors are fused by LoRA. 
Fused LoRA layers cannot be cleared by `pipe.clear_lora()`.") diff --git a/DiffSynth-Studio/diffsynth/utils/lora/merge.py b/DiffSynth-Studio/diffsynth/utils/lora/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..61904ff4bcebc6c344c23f26073aec292355217c --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/lora/merge.py @@ -0,0 +1,20 @@ +import torch +from typing import Dict, List + + +def merge_lora_weight(tensors_A, tensors_B): + lora_A = torch.concat(tensors_A, dim=0) + lora_B = torch.concat(tensors_B, dim=1) + return lora_A, lora_B + + +def merge_lora(loras: List[Dict[str, torch.Tensor]], alpha=1): + lora_merged = {} + keys = [i for i in loras[0].keys() if ".lora_A." in i] + for key in keys: + tensors_A = [lora[key] for lora in loras] + tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras] + lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B) + lora_merged[key] = lora_A * alpha + lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B + return lora_merged diff --git a/DiffSynth-Studio/diffsynth/utils/lora/reset_rank.py b/DiffSynth-Studio/diffsynth/utils/lora/reset_rank.py new file mode 100644 index 0000000000000000000000000000000000000000..9522b043ff962bc050fa79596197f00abf3877b0 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/lora/reset_rank.py @@ -0,0 +1,20 @@ +import torch + +def decomposite(tensor_A, tensor_B, rank): + dtype, device = tensor_A.dtype, tensor_A.device + weight = tensor_B @ tensor_A + U, S, V = torch.pca_lowrank(weight.float(), q=rank) + tensor_A = (V.T).to(dtype=dtype, device=device).contiguous() + tensor_B = (U @ torch.diag(S)).to(dtype=dtype, device=device).contiguous() + return tensor_A, tensor_B + +def reset_lora_rank(lora, rank): + lora_merged = {} + keys = [i for i in lora.keys() if ".lora_A." 
in i] + for key in keys: + tensor_A = lora[key] + tensor_B = lora[key.replace(".lora_A.", ".lora_B.")] + tensor_A, tensor_B = decomposite(tensor_A, tensor_B, rank) + lora_merged[key] = tensor_A + lora_merged[key.replace(".lora_A.", ".lora_B.")] = tensor_B + return lora_merged \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/__init__.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/__pycache__/__init__.cpython-39.pyc b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39fe58f7de0c384ed34ac79e58a744cb22e5158e Binary files /dev/null and b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/__pycache__/__init__.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/__pycache__/wan_video_vae.cpython-39.pyc b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/__pycache__/wan_video_vae.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a459931f668f11727d4a2ec8fd1293d02f6b475 Binary files /dev/null and b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/__pycache__/wan_video_vae.cpython-39.pyc differ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux2_text_encoder.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux2_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0975e62a35021c697192ad054f0e3aff42289292 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux2_text_encoder.py @@ -0,0 +1,17 @@ +def Flux2TextEncoderStateDictConverter(state_dict): + rename_dict = { + "multi_modal_projector.linear_1.weight": "model.multi_modal_projector.linear_1.weight", + "multi_modal_projector.linear_2.weight": "model.multi_modal_projector.linear_2.weight", + "multi_modal_projector.norm.weight": "model.multi_modal_projector.norm.weight", + "multi_modal_projector.patch_merger.merging_layer.weight": "model.multi_modal_projector.patch_merger.merging_layer.weight", + "language_model.lm_head.weight": "lm_head.weight", + } + state_dict_ = {} + for k in state_dict: + k_ = k + k_ = k_.replace("language_model.model", "model.language_model") + k_ = k_.replace("vision_tower", "model.vision_tower") + if k_ in rename_dict: + k_ = rename_dict[k_] + state_dict_[k_] = state_dict[k] + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_controlnet.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..15f9447d22bc0ebc2dbb3d2eac8dbf0bd78e4151 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_controlnet.py @@ -0,0 +1,103 @@ +import torch + + +def FluxControlNetStateDictConverter(state_dict): + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + 
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + + for name in state_dict: + param = state_dict[name] + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." 
in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + + return state_dict_ \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_dit.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..f808b60defd800ff97ec78fec1dac6f472038cb7 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_dit.py @@ -0,0 +1,197 @@ +import torch + + +def FluxDiTStateDictConverter(state_dict): + is_nexus_gen = sum([key.startswith("pipe.dit.") for key in state_dict]) > 0 + if is_nexus_gen: + dit_state_dict = {} + for key in state_dict: + if key.startswith('pipe.dit.'): + param = state_dict[key] + new_key = key.replace("pipe.dit.", "") + if new_key.startswith("final_norm_out.linear."): + param = torch.concat([param[3072:], param[:3072]], dim=0) + dit_state_dict[new_key] = param + return dit_state_dict + + rename_dict = { + "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", + "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight", + "time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias", + "time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight", + "txt_in.bias": "context_embedder.bias", + "txt_in.weight": "context_embedder.weight", + "vector_in.in_layer.bias": "pooled_text_embedder.0.bias", + "vector_in.in_layer.weight": "pooled_text_embedder.0.weight", + "vector_in.out_layer.bias": "pooled_text_embedder.2.bias", + "vector_in.out_layer.weight": "pooled_text_embedder.2.weight", + "final_layer.linear.bias": "final_proj_out.bias", + "final_layer.linear.weight": "final_proj_out.weight", + "guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias", + "guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight", + "guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias", + "guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight", + "img_in.bias": "x_embedder.bias", + "img_in.weight": "x_embedder.weight", + "final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight", + "final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias", + } + suffix_rename_dict = { + "img_attn.norm.key_norm.scale": "attn.norm_k_a.weight", + "img_attn.norm.query_norm.scale": "attn.norm_q_a.weight", + "img_attn.proj.bias": "attn.a_to_out.bias", + "img_attn.proj.weight": "attn.a_to_out.weight", + "img_attn.qkv.bias": "attn.a_to_qkv.bias", + "img_attn.qkv.weight": "attn.a_to_qkv.weight", + "img_mlp.0.bias": "ff_a.0.bias", + "img_mlp.0.weight": "ff_a.0.weight", + "img_mlp.2.bias": "ff_a.2.bias", + "img_mlp.2.weight": "ff_a.2.weight", + "img_mod.lin.bias": "norm1_a.linear.bias", + "img_mod.lin.weight": "norm1_a.linear.weight", + "txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight", + "txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight", + "txt_attn.proj.bias": "attn.b_to_out.bias", + 
"txt_attn.proj.weight": "attn.b_to_out.weight", + "txt_attn.qkv.bias": "attn.b_to_qkv.bias", + "txt_attn.qkv.weight": "attn.b_to_qkv.weight", + "txt_mlp.0.bias": "ff_b.0.bias", + "txt_mlp.0.weight": "ff_b.0.weight", + "txt_mlp.2.bias": "ff_b.2.bias", + "txt_mlp.2.weight": "ff_b.2.weight", + "txt_mod.lin.bias": "norm1_b.linear.bias", + "txt_mod.lin.weight": "norm1_b.linear.weight", + + "linear1.bias": "to_qkv_mlp.bias", + "linear1.weight": "to_qkv_mlp.weight", + "linear2.bias": "proj_out.bias", + "linear2.weight": "proj_out.weight", + "modulation.lin.bias": "norm.linear.bias", + "modulation.lin.weight": "norm.linear.weight", + "norm.key_norm.scale": "norm_k_a.weight", + "norm.query_norm.scale": "norm_q_a.weight", + } + state_dict_ = {} + for name in state_dict: + original_name = name + if name.startswith("model.diffusion_model."): + name = name[len("model.diffusion_model."):] + names = name.split(".") + if name in rename_dict: + rename = rename_dict[name] + state_dict_[rename] = state_dict[original_name] + elif names[0] == "double_blocks": + rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] + state_dict_[rename] = state_dict[original_name] + elif names[0] == "single_blocks": + if ".".join(names[2:]) in suffix_rename_dict: + rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] + state_dict_[rename] = state_dict[original_name] + else: + pass + return state_dict_ + + +def FluxDiTStateDictConverterFromDiffusers(state_dict): + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + for name in state_dict: + param = state_dict[name] + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + if global_rename_dict[prefix] == "final_norm_out.linear": + param = torch.concat([param[3072:], param[:3072]], dim=0) + state_dict_[global_rename_dict[prefix] + suffix] = param + elif 
prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + pass + else: + pass + for name in list(state_dict_.keys()): + if "single_blocks." in name and ".a_to_q." in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + mlp = torch.zeros(4 * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") + state_dict_[name_] = param + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + return state_dict_ \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_infiniteyou.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_infiniteyou.py new file mode 100644 index 0000000000000000000000000000000000000000..7025b392d54c5b4844ed3b3387bd010217897f4a --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_infiniteyou.py @@ -0,0 +1,2 @@ +def FluxInfiniteYouImageProjectorStateDictConverter(state_dict): + return state_dict['image_proj'] \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_ipadapter.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_ipadapter.py new file mode 100644 index 0000000000000000000000000000000000000000..86dfb133655fbe9c33c84b419706a103cec96b1b --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_ipadapter.py @@ -0,0 +1,32 @@ +def FluxIpAdapterStateDictConverter(state_dict): + state_dict_ = {} + + if "ip_adapter" in state_dict and isinstance(state_dict["ip_adapter"], dict): + for name, param in state_dict["ip_adapter"].items(): + name_ = 'ipadapter_modules.' + name + state_dict_[name_] = param + + if "image_proj" in state_dict: + for name, param in state_dict["image_proj"].items(): + name_ = "image_proj." 
+ name + state_dict_[name_] = param + return state_dict_ + + for key, value in state_dict.items(): + if key.startswith("image_proj."): + state_dict_[key] = value + elif key.startswith("ip_adapter."): + new_key = key.replace("ip_adapter.", "ipadapter_modules.") + state_dict_[new_key] = value + else: + pass + + return state_dict_ + + +def SiglipStateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + if key.startswith("vision_model."): + new_state_dict[key] = state_dict[key] + return new_state_dict \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..aa018aa5c570cc67f4856002e8f1f83f18998e07 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py @@ -0,0 +1,31 @@ +def FluxTextEncoderClipStateDictConverter(state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias", + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..d35eb831d2a7b1d48eee747d251d6cfb6ad508ef --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py @@ -0,0 +1,4 @@ +def FluxTextEncoderT5StateDictConverter(state_dict): + state_dict_ = {i: state_dict[i] for i in state_dict} + state_dict_["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_vae.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..6547f18f1e1cfe69d0cf4ef43860702812d25fab --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/flux_vae.py @@ -0,0 +1,382 @@ +def FluxVAEEncoderStateDictConverter(state_dict): + rename_dict = { + "encoder.conv_in.bias": "conv_in.bias", + "encoder.conv_in.weight": "conv_in.weight", + "encoder.conv_out.bias": "conv_out.bias", + "encoder.conv_out.weight": "conv_out.weight", + "encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias", + "encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight", + "encoder.down.0.block.0.conv2.bias": 
"blocks.0.conv2.bias", + "encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight", + "encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias", + "encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight", + "encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias", + "encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight", + "encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias", + "encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight", + "encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias", + "encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight", + "encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias", + "encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight", + "encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias", + "encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight", + "encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias", + "encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight", + "encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias", + "encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight", + "encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias", + "encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight", + "encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias", + "encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight", + "encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias", + "encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight", + "encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias", + "encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight", + "encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias", + "encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight", + "encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias", + "encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight", + "encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias", + "encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight", + "encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias", + "encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight", + "encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias", + "encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight", + "encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias", + "encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight", + "encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias", + "encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight", + "encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias", + "encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight", + "encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias", + "encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight", + "encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias", + "encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight", + "encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias", + "encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight", + "encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias", + "encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight", + "encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias", + "encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight", + "encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias", + "encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight", + "encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias", + 
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight", + "encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias", + "encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight", + "encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias", + "encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight", + "encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias", + "encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight", + "encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias", + "encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight", + "encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias", + "encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight", + "encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias", + "encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight", + "encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias", + "encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight", + "encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias", + "encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight", + "encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias", + "encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight", + "encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias", + "encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight", + "encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias", + "encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight", + "encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias", + "encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight", + "encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias", + "encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight", + "encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias", + "encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight", + "encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias", + "encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight", + "encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias", + "encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight", + "encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias", + "encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight", + "encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias", + "encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight", + "encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias", + "encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight", + "encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias", + "encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight", + "encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias", + "encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight", + "encoder.norm_out.bias": "conv_norm_out.bias", + "encoder.norm_out.weight": "conv_norm_out.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + state_dict_[rename_dict[name]] = param + return state_dict_ + + +def FluxVAEDecoderStateDictConverter(state_dict): + rename_dict = { + "decoder.conv_in.bias": "conv_in.bias", + "decoder.conv_in.weight": "conv_in.weight", + "decoder.conv_out.bias": "conv_out.bias", + "decoder.conv_out.weight": "conv_out.weight", + "decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias", + "decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight", + "decoder.mid.attn_1.norm.bias": 
"blocks.1.norm.bias", + "decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight", + "decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias", + "decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight", + "decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias", + "decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight", + "decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias", + "decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight", + "decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias", + "decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight", + "decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias", + "decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight", + "decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias", + "decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight", + "decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias", + "decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight", + "decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias", + "decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight", + "decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias", + "decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight", + "decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias", + "decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight", + "decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias", + "decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight", + "decoder.norm_out.bias": "conv_norm_out.bias", + "decoder.norm_out.weight": "conv_norm_out.weight", + "decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias", + "decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight", + "decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias", + "decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight", + "decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias", + "decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight", + "decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias", + "decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight", + "decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias", + "decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight", + "decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias", + "decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight", + "decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias", + "decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight", + "decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias", + "decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight", + "decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias", + "decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight", + "decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias", + "decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight", + "decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias", + "decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight", + "decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias", + "decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight", + "decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias", + "decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight", + "decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias", + "decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight", + "decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias", + "decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight", 
+ "decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias", + "decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight", + "decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias", + "decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight", + "decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias", + "decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight", + "decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias", + "decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight", + "decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias", + "decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight", + "decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias", + "decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight", + "decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias", + "decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight", + "decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias", + "decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight", + "decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias", + "decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight", + "decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias", + "decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight", + "decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias", + "decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight", + "decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias", + "decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight", + "decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias", + "decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight", + "decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias", + "decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight", + "decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias", + "decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight", + "decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias", + "decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight", + "decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias", + "decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight", + "decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias", + "decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight", + "decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias", + "decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight", + "decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias", + "decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight", + "decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias", + "decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight", + "decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias", + "decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight", + "decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias", + "decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight", + "decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias", + "decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight", + "decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias", + "decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight", + "decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias", + "decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight", + "decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias", + "decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight", + "decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias", + "decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight", + 
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias", + "decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight", + "decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias", + "decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight", + "decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias", + "decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight", + "decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias", + "decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight", + "decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias", + "decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight", + "decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias", + "decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight", + "decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias", + "decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight", + "decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias", + "decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight", + "decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias", + "decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight", + "decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias", + "decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + state_dict_[rename_dict[name]] = param + return state_dict_ + + +def FluxVAEEncoderStateDictConverterDiffusers(state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', + 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock' + ] + + # Rename each parameter + local_rename_dict = { + "quant_conv": "quant_conv", + "encoder.conv_in": "conv_in", + "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm", + "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q", + "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k", + "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v", + "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out", + "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1", + "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1", + "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2", + "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2", + "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1", + "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1", + "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2", + "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2", + "encoder.conv_norm_out": "conv_norm_out", + "encoder.conv_out": "conv_out", + } + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1} + last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + name_prefix = ".".join(names[:-1]) + if name_prefix in local_rename_dict: + rename_dict[name] = local_rename_dict[name_prefix] + "." 
+ names[-1] + elif name.startswith("encoder.down_blocks"): + block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]] + block_type_with_id = ".".join(names[:5]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:5]) + names = ["blocks", str(block_id[block_type])] + names[5:] + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + return state_dict_ + + +def FluxVAEDecoderStateDictConverterDiffusers(state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock' + ] + + # Rename each parameter + local_rename_dict = { + "post_quant_conv": "post_quant_conv", + "decoder.conv_in": "conv_in", + "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm", + "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q", + "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k", + "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v", + "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out", + "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1", + "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1", + "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2", + "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2", + "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1", + "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1", + "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2", + "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2", + "decoder.conv_norm_out": "conv_norm_out", + "decoder.conv_out": "conv_out", + } + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2} + last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + name_prefix = ".".join(names[:-1]) + if name_prefix in local_rename_dict: + rename_dict[name] = local_rename_dict[name_prefix] + "." 
+ names[-1] + elif name.startswith("decoder.up_blocks"): + block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]] + block_type_with_id = ".".join(names[:5]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:5]) + names = ["blocks", str(block_id[block_type])] + names[5:] + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/nexus_gen.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/nexus_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..aff853d0e76dd1f130ce44241462f61af84370db --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/nexus_gen.py @@ -0,0 +1,6 @@ +def NexusGenAutoregressiveModelStateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + value = state_dict[key] + new_state_dict["model." + key] = value + return new_state_dict \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/nexus_gen_projector.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/nexus_gen_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a44665551ba4a97d063de94f6025c8c48989fd --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/nexus_gen_projector.py @@ -0,0 +1,15 @@ +def NexusGenMergerStateDictConverter(state_dict): + merger_state_dict = {} + for key in state_dict: + if key.startswith('embedding_merger.'): + value = state_dict[key] + new_key = key.replace("embedding_merger.", "") + merger_state_dict[new_key] = value + return merger_state_dict + +def NexusGenAdapterStateDictConverter(state_dict): + adapter_state_dict = {} + for key in state_dict: + if key.startswith('adapter.'): + adapter_state_dict[key] = state_dict[key] + return adapter_state_dict \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e8192a1f2a959685cf1fa5af40824bd896454141 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/qwen_image_text_encoder.py @@ -0,0 +1,10 @@ +def QwenImageTextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for k in state_dict: + v = state_dict[k] + if k.startswith("visual."): + k = "model." 
+ k + elif k.startswith("model."): + k = k.replace("model.", "model.language_model.") + state_dict_[k] = v + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/step1x_connector.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/step1x_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..35a2a4167b16ea5cc16aaa1b0f20575bc2918bbf --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/step1x_connector.py @@ -0,0 +1,7 @@ +def Qwen2ConnectorStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("connector."): + name_ = name[len("connector."):] + state_dict_[name_] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea69f4e6696bbef6de197abaa031ea8cc5b398e --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_animate_adapter.py @@ -0,0 +1,6 @@ +def WanAnimateAdapterStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("pose_patch_embedding.") or name.startswith("face_adapter") or name.startswith("face_encoder") or name.startswith("motion_encoder"): + state_dict_[name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_dit.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..c7716dad52e42ebf76f98dd85511ac0a04b3d3b3 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_dit.py @@ -0,0 +1,83 @@ +def WanVideoDiTFromDiffusers(state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias", + "blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight", + "blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias", + "blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight", + 
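+ # Only block 0 is spelled out in this mapping; the fallback branch below rewrites any other "blocks.<i>.*" key through the "blocks.0.*" template and then restores the original block index, e.g. "blocks.17.attn1.to_q.weight" -> "blocks.17.self_attn.q.weight".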
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias", + "condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight", + "condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias", + "condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight", + "condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias", + "condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight", + "condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias", + "condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + state_dict_[rename_dict[name]] = state_dict[name] + else: + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = state_dict[name] + return state_dict_ + + +def WanVideoDiTStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vace"): + continue + if name.split(".")[0] in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]: + continue + name_ = name + if name_.startswith("model."): + name_ = name_[len("model."):] + state_dict_[name_] = state_dict[name] + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb7e9bfce50e88601f8876341ac56645a8e5913 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_image_encoder.py @@ -0,0 +1,8 @@ +def WanImageEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("textual."): + continue + name_ = "model." 
+ name + state_dict_[name_] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_mot.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_mot.py new file mode 100644 index 0000000000000000000000000000000000000000..12b42d7db752fca1cb24c0f16217deab925916f5 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_mot.py @@ -0,0 +1,78 @@ +def WanVideoMotStateDictConverter(state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias", + "blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight", + "blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias", + "blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight", + "blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", + "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias", + "condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight", + 
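+ # This converter keeps only parameters whose names contain the "_mot_ref" marker; after that marker is stripped, the sparse block indices (0, 4, 8, ..., 36) are remapped to a dense 0-9 range via mot_layers_mapping before the same block-0 rename scheme is applied.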
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias", + "condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight", + "condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias", + "condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight", + "condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias", + "condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36) + mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)} + state_dict_ = {} + for name in state_dict: + if "_mot_ref" not in name: + continue + param = state_dict[name] + name = name.replace("_mot_ref", "") + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + if name.split(".")[1].isdigit(): + block_id = int(name.split(".")[1]) + name = name.replace(str(block_id), str(mot_layers_mapping[block_id])) + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = param + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_vace.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..cdfef6998f47ac7d3640b28b99f109e7f04baeba --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_vace.py @@ -0,0 +1,3 @@ +def VaceWanModelDictConverter(state_dict): + state_dict_ = {name: state_dict[name] for name in state_dict if name.startswith("vace")} + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_vae.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..76a430e1bd4575e0ae06234de23b620d4877566f --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wan_video_vae.py @@ -0,0 +1,7 @@ +def WanVideoVAEStateDictConverter(state_dict): + state_dict_ = {} + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + for name in state_dict: + state_dict_['model.' + name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa12c0d4ff7fc166eea1f804cb645be9aa28776 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/wans2v_audio_encoder.py @@ -0,0 +1,12 @@ +def WanS2VAudioEncoderStateDictConverter(state_dict): + rename_dict = { + "model.wav2vec2.encoder.pos_conv_embed.conv.weight_g": "model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0", + "model.wav2vec2.encoder.pos_conv_embed.conv.weight_v": "model.wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1", + } + state_dict_ = {} + for name in state_dict: + name_ = "model." 
+ name + if name_ in rename_dict: + name_ = rename_dict[name_] + state_dict_[name_] = state_dict[name] + return state_dict_ diff --git a/DiffSynth-Studio/diffsynth/utils/state_dict_converters/z_image_text_encoder.py b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/z_image_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b11461345e808edd2c4ca793419ae137b70bfbc9 --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/state_dict_converters/z_image_text_encoder.py @@ -0,0 +1,6 @@ +def ZImageTextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name != "lm_head.weight": + state_dict_[name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/DiffSynth-Studio/diffsynth/utils/xfuser/__init__.py b/DiffSynth-Studio/diffsynth/utils/xfuser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13dd178e2d47bf58de1bfca6d21052d0563c70ca --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/xfuser/__init__.py @@ -0,0 +1 @@ +from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp diff --git a/DiffSynth-Studio/diffsynth/utils/xfuser/xdit_context_parallel.py b/DiffSynth-Studio/diffsynth/utils/xfuser/xdit_context_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..21dc3b33c854aebf5deb723420abcf3913cbe6fd --- /dev/null +++ b/DiffSynth-Studio/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -0,0 +1,146 @@ +import torch +from typing import Optional +from einops import rearrange +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) +from xfuser.core.long_ctx_attention import xFuserLongContextAttention +from ...core.device import parse_nccl_backend, parse_device_type + + +def initialize_usp(device_type): + import torch.distributed as dist + from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment + dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://") + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=1, + ulysses_degree=dist.get_world_size(), + ) + getattr(torch, device_type).set_device(dist.get_rank()) + + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + s_per_rank = x.shape[1] + + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs = pad_freqs(freqs, s_per_rank * sp_size) + freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + freqs_rank = freqs_rank.to(torch.complex64) if 
freqs_rank.device == "npu" else freqs_rank + x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) + return x_out.to(x.dtype) + +def usp_dit_forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + # Context Parallel + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] + chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] + x = chunks[get_sequence_parallel_rank()] + + for block in self.blocks: + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + x = x[:, :-pad_shape] if pad_shape > 0 else x + + # unpatchify + x = self.unpatchify(x, (f, h, w)) + return x + + +def usp_attn_forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) + + x = xFuserLongContextAttention()( + None, + query=q, + key=k, + value=v, + ) + x = x.flatten(2) + + del q, k, v + getattr(torch, parse_device_type(x.device)).empty_cache() + return self.o(x) \ No newline at end of file diff --git a/README.md b/README.md index 448f7158a17b024e6c85b3bba7b982fe6e93d4a7..a3109f43f0bff9a01cc0c89e874b32d36f6606df 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ emoji: 📈 colorFrom: yellow colorTo: yellow sdk: gradio -sdk_version: 6.13.0 +sdk_version: 4.44.1 app_file: app.py pinned: false --- diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..be67b6dd778e27117872f421f1df21bb2a22c8b9 --- /dev/null +++ b/app.py @@ -0,0 +1,460 @@ +import torch, os, re +import numpy as np +import gradio as gr +from PIL import Image +from scipy.spatial.transform import Rotation +import cv2, sys +from huggingface_hub import snapshot_download +import spaces +os.system("python -m conda install 
pytorch3d-0.7.8-py39_cu121_pyt241.tar.bz2") +# ===== VGGT ===== +sys.path.append(os.path.join(os.getcwd(), "/vggt")) +from vggt.models.vggt import VGGT +from vggt.utils.load_fn import load_and_preprocess_images +from vggt.utils.pose_enc import pose_encoding_to_extri_intri + + + +# ===== Wan ===== +sys.path.append(os.path.join(os.getcwd(), "/DiffSynth-Studio")) +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from safetensors.torch import load_file + +# ===== PyTorch3D ===== +from pytorch3d.structures import Pointclouds +from pytorch3d.renderer import ( + PerspectiveCameras, + PointsRasterizationSettings, + PointsRenderer, + PointsRasterizer, + AlphaCompositor, +) + + +def todevice(batch, device, callback=None, non_blocking=False): + ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). + + batch: list, tuple, dict of tensors or other things + device: pytorch device or 'numpy' + callback: function that would be called on every sub-elements. + ''' + if isinstance(batch, dict): + return {k: todevice(v, device) for k, v in batch.items()} + + if isinstance(batch, (tuple, list)): + return type(batch)(todevice(x, device) for x in batch) + + x = batch + if device == 'numpy': + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +def to_numpy(x): return todevice(x, 'numpy') + + + +# ========================= +# Global configs (CHANGE THESE PATHS) +# ========================= +hf_token = os.getenv("HF_TOKEN") + +snapshot_download(repo_id="facebook/VGGT-1B", local_dir="./VGGT_PATH", token=hf_token) +snapshot_download(repo_id="Wan-AI/Wan2.2-TI2V-5B", local_dir="./WAN_MODEL_DIR", token=hf_token) +snapshot_download(repo_id="123123aa123/UniGeo", local_dir="./LORA_PATH", token=hf_token) + +VGGT_PATH = "./VGGT_PATH" +WAN_MODEL_DIR = "./WAN_MODEL_DIR" +LORA_PATH = "./LORA_PATH" +WAN_CONFIG_PATH = "./my_config.json" + + +# ========================= +# Global models +# ========================= +device = "cuda" if torch.cuda.is_available() else "cpu" +dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16 + +vggt_model = None +wan_pipe = None + + +# ========================= +# Load models once +# ========================= +def load_models(): + global vggt_model, wan_pipe + + if vggt_model is None: + print("Loading VGGT...") + vggt_model = VGGT.from_pretrained(VGGT_PATH).to(device).eval() + + if wan_pipe is None: + print("Loading Wan...") + + wan_paths = [ + os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00001-of-00003.safetensors"), + os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00002-of-00003.safetensors"), + os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00003-of-00003.safetensors"), + ] + + wan_pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(path=os.path.join(WAN_MODEL_DIR, "models_t5_umt5-xxl-enc-bf16.pth")), + ModelConfig(path=os.path.join(WAN_MODEL_DIR, "Wan2.2_VAE.pth")), + ], + tokenizer_config=ModelConfig(path=os.path.join(WAN_MODEL_DIR, "google/umt5-xxl/")), + wan_paths=wan_paths, + wan_config_path=WAN_CONFIG_PATH + ) + + ckpt = load_file(LORA_PATH) + lora_sd, adapter_sd = {}, {} + + for k, v in ckpt.items(): + if ".lora_" in k: + lora_sd[k] = v + elif "i2v_adapter" in k: + adapter_sd[k] = v + + wan_pipe.load_lora(wan_pipe.dit, state_dict=lora_sd, alpha=1) + 
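+ # The downloaded UniGeo checkpoint bundles two kinds of weights: LoRA deltas (keys containing ".lora_") injected into the DiT via load_lora above, and i2v_adapter weights merged directly into the module below with strict=False.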
wan_pipe.dit.load_state_dict(adapter_sd, strict=False) + + wan_pipe.to(device) + wan_pipe.to(dtype=torch.bfloat16) + + +# ========================= +# Renderer +# ========================= +def setup_renderer(cameras, image_size): + raster_settings = PointsRasterizationSettings( + image_size=image_size, + radius = 0.01, + points_per_pixel = 10, + bin_size = 0 + ) + + renderer = PointsRenderer( + rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings), + compositor=AlphaCompositor() + ) + + render_setup = {'cameras': cameras, 'raster_settings': raster_settings, 'renderer': renderer} + + return render_setup + + +def render_pcd(pts3d, imgs, masks, views, renderer, device, nbv=False): + imgs = to_numpy(imgs) + pts3d = to_numpy(pts3d) + + if masks is None: + pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device) + col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device) + else: + pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device) + col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device) + + point_cloud = Pointclouds(points=[pts], features=[col]).extend(views) + images = renderer(point_cloud) + + if nbv: + color_mask = torch.ones(col.shape).to(device) + point_cloud_mask = Pointclouds(points=[pts], features=[color_mask]).extend(views) + view_masks = renderer(point_cloud_mask) + else: + view_masks = None + + return images, view_masks + +def run_render(pcd, imgs, masks, H, W, camera_traj, num_views, device, nbv=True): + render_setup = setup_renderer(camera_traj, image_size=(H,W)) + renderer = render_setup['renderer'] + render_results, viewmask = render_pcd(pcd, imgs, masks, num_views, renderer, device, nbv=nbv) + return render_results, viewmask + + +# ========================= +# Prompt parsing +# ========================= +def generate_all_motions_from_prompt(prompt, num_frames): + + x, y, z, phi, theta = parse_prompt_to_motion(prompt) + + results = [] + + for i in range(num_frames): + alpha = i / (num_frames - 1) + + results.append(( + x * alpha, + y * alpha, + z * alpha, + phi * alpha, + theta * alpha + )) + + return results + + +def parse_prompt_to_motion(prompt): + prompt = prompt.lower() + x = y = z = phi = theta = 0.0 + + clauses = re.split(r'[;,\n]| and ', prompt) + + for clause in clauses: + + nums = re.findall(r"[-+]?\d*\.?\d+", clause) + + if not nums: + continue + + val = float(nums[0]) + + if "pans left" in clause: + phi = -val + elif "pans right" in clause: + phi = val + elif "tilts up" in clause: + theta = val + elif "tilts down" in clause: + theta = -val + elif "moves forward" in clause: + z = val + elif "moves backward" in clause: + z = -val + elif "moves up" in clause: + y = -val + elif "moves down" in clause: + y = val + elif "moves left" in clause: + x = -val + elif "moves right" in clause: + x = val + + return x, y, z, phi, theta + + +def build_estimate_rel(x, y, z, phi, theta): + + delta_euler = [theta, phi, 0.0] + rot_mat = Rotation.from_euler('xyz', delta_euler, degrees=True).as_matrix() + + mat = np.eye(4) + mat[:3, :3] = rot_mat + mat[:3, 3] = [x, y, z] + return mat + + +# ========================= +# Main inference +# ========================= + +@spaces.GPU +def infer(image, prompt, seed): + + load_models() + + + img = image.convert("RGB") + + TARGET_H, TARGET_W = img.size[1], img.size[0] + TARGET_H = TARGET_H // 32 * 32 + TARGET_W = TARGET_W // 32 * 32 + + img = img.resize((TARGET_W, TARGET_H), Image.BICUBIC) + + all_steps = 
generate_all_motions_from_prompt(prompt, num_frames=81) + + cam_idx = list(range(81)) + traj = [build_estimate_rel(*all_steps[idx]) for idx in cam_idx] + + first_frame = [img, img] + first_frame = load_and_preprocess_images(first_frame) + first_frame = first_frame.to(device) + + + + with torch.no_grad(): + with torch.cuda.amp.autocast(dtype=dtype): + predictions = vggt_model(first_frame) + + extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], first_frame.shape[-2:]) + + first_frame_world_points = predictions["world_points"][0][0] + + focals = intrinsic[0][0][:2, :2].diag().unsqueeze(0).to(device) + principal_points = intrinsic[0][0][:2, 2].unsqueeze(0).to(device) + + raw_image = first_frame[0].cpu().numpy() + raw_image = raw_image.transpose(1, 2, 0) + + render_results_list = [] + + + for estimate_rel in traj: + estimate_rel = torch.from_numpy(estimate_rel).float().to(device) + relative_c2ws = estimate_rel.unsqueeze(0) + R, T = relative_c2ws[:, :3, :3], relative_c2ws[:, :3, 3:] + R = torch.stack([-R[:, :, 0], -R[:, :, 1], R[:, :, 2]], 2) + new_c2w = torch.cat([R, T], 2) + + w2c = torch.linalg.inv(torch.cat( + (new_c2w, torch.Tensor([[[0, 0, 0, 1]]]).to(device).repeat(new_c2w.shape[0], 1, 1)), + 1 + )) + R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3] + + + image_size = (first_frame.shape[-2:],) + + cameras = PerspectiveCameras( + focal_length=focals, + principal_point=principal_points, + in_ndc=False, + image_size=image_size, + R=R_new, + T=T_new, + device=device + ) + + masks = None + render_results, viewmask = run_render( + [first_frame_world_points], + [raw_image], + masks, + image_size[0][0], image_size[0][1], + cameras, + 1, + device=device + ) + + + render_result = (render_results[-1].detach().cpu().numpy() * 255).astype(np.uint8) + + if len(render_result.shape) == 2: + render_result = cv2.cvtColor(render_result, cv2.COLOR_GRAY2RGB) + elif render_result.shape[-1] == 4: + render_result = render_result[..., :3] + + render_results_list.append(render_result) + + + raw_image = first_frame[0].cpu().numpy() + raw_image = raw_image.transpose(1, 2, 0) + + raw_image = (raw_image * 255).clip(0, 255).astype(np.uint8) + + render_results_list[0] = raw_image + + frame_indices = np.linspace( + 0, + 80, + 25 + ).round().astype(int) + + frames = [] + for idx in frame_indices: + frame = render_results_list[idx] + frame = Image.fromarray(frame) + frames.append(frame) + + + last = frames[-1] + for _ in range(4): + frames.append(last) + + # TARGET_H, TARGET_W = 704, 1248 + + def resize_pil(img): + return img.resize((TARGET_W, TARGET_H), Image.BICUBIC) + + frames = [resize_pil(f) for f in frames] + image = resize_pil(image) + + # ===== Wan ===== + video = wan_pipe( + prompt="Ensure the consistency of the video", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + src_video=frames, + input_image=image, + height=TARGET_H, + width=TARGET_W, + cfg_scale=5.0, + num_frames=29, + num_inference_steps=50, + seed=int(seed), + tiled=True + ) + + video_frames = list(video) + last_frame = np.array(video_frames[-1]) + + pcd_last = frames[-1] + + return Image.fromarray(last_frame), pcd_last + + +# ========================= +# Gradio UI +# ========================= +with gr.Blocks() as demo: + + # ===== 标题 + 说明 ===== + gr.Markdown(""" +
+ +UniGeo: Unifying Geometric Guidance for Camera-Controllable Image Editing via Video Models
+ +
+ +Input Requirement / 输入要求
+The input image is recommended to have width ≥ height due to VGGT and Wan model constraints.
+由于 VGGT 与 Wan 模型限制,建议输入图像满足 宽 ≥ 高。 + +
+ +Usage Guide / 使用说明
+You can input one or multiple camera commands separated by semicolons, such as “Camera pans left by 15 degrees” or “Camera moves left by 0.27; Camera pans right by 26 degrees”. The motion scale is normalized by VGGT, and the final point cloud is provided to help adjust motion parameters.
+支持输入一条或多条相机控制指令(使用分号分隔),例如“Camera pans left by 15 degrees”或“Camera moves left by 0.27; Camera pans right by 26 degrees”。所有运动数值由 VGGT 统一尺度建模,最终提供的点云结果可用于辅助调整相机运动参数。 + +
+""") + + # ===== 输入输出图 ===== + with gr.Row(): + inp = gr.Image(type="pil", label="Input Image") + out = gr.Image(type="numpy", label="Output Image") + + # ===== prompt + seed ===== + with gr.Row(): + txt = gr.Textbox(label="Camera Prompt") + seed_inp = gr.Number(value=0, label="Seed", precision=0) + + run_btn = gr.Button("Run") + + # ===== 点云输出 ===== + pcd_out = gr.Image(type="pil", label="Final Frame Point Cloud") + + # ===== 绑定 ===== + run_btn.click( + fn=infer, + inputs=[inp, txt, seed_inp], + outputs=[out, pcd_out] + ) + +if __name__ == "__main__": + demo.queue().launch( + server_name="0.0.0.0", + server_port=7860 + ) \ No newline at end of file diff --git a/my_config.json b/my_config.json new file mode 100644 index 0000000000000000000000000000000000000000..66ed5c1e32e6d535005bb55a61ceb8f9d5bdb672 --- /dev/null +++ b/my_config.json @@ -0,0 +1,17 @@ +{ + "has_image_input": false, + "patch_size": [1, 2, 2], + "in_dim": 48, + "dim": 3072, + "ffn_dim": 14336, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 48, + "num_heads": 24, + "num_layers": 30, + "eps": 1e-06, + "seperated_timestep": true, + "require_clip_embedding": false, + "require_vae_embedding": false, + "fuse_vae_embedding_in_latents": true +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bd5ed8653076ba5ed88bfeebe3ff41144c50db7a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +diffusers==0.35.1 +einops==0.6.1 +huggingface-hub==0.36.0 +imageio==2.27.0 +imageio-ffmpeg==0.4.8 +numpy==1.23.5 +opencv-python==4.7.0.72 +peft==0.17.0 +protobuf==3.20.3 +scikit-image==0.20.0 +scikit-learn==1.2.2 +scipy==1.9.1 +sentencepiece==0.2.1 +tokenizers==0.21.1 +torch==2.4.1 +torchvision==0.19.1 +tqdm==4.67.3 +transformers==4.52.3 +ftfy==6.3.1 +gradio==4.44.1 +gradio_client==1.3.0 +pydantic==2.10.6 \ No newline at end of file diff --git a/vggt/CODE_OF_CONDUCT.md b/vggt/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..3232ed665566ec047ce55a929db1581dbda266a1 --- /dev/null +++ b/vggt/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/vggt/CONTRIBUTING.md b/vggt/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..72baaa2eb86da6050a43c1ea553c095932a5b939 --- /dev/null +++ b/vggt/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to vggt +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. 
If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to vggt, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/vggt/LICENSE.txt b/vggt/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..079cb728fea5b11bb46c0c9f6ad4b3d2e4d216c7 --- /dev/null +++ b/vggt/LICENSE.txt @@ -0,0 +1,115 @@ +VGGT License + +v1 Last Updated: July 29, 2025 + +“Acceptable Use Policy” means the Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement. + +“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein. + + +“Documentation” means the specifications, manuals and documentation accompanying +Research Materials distributed by Meta. + + +“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + +“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). +“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement. + +By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement. + + +1. License Rights and Redistribution. + + +a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials. + +b. Redistribution and Use. + + +i. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. 
You shall also provide a copy of this Agreement to such third party. + + +ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication. + + +iii. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy, which is hereby incorporated by reference into this Agreement. +2. User Support. Your use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind. + + +3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. + +5. Intellectual Property. + + +a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. + +b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials. + +6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement. + +7. Governing Law and Jurisdiction. 
This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. + + +8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta. + + +Acceptable Use Policy + +Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all. + +As part of this mission, Meta makes certain research materials available for use in accordance with this Agreement (including the Acceptable Use Policy). Meta is committed to promoting the safe and responsible use of such research materials. + +Prohibited Uses + +You agree you will not use, or allow others to use, Research Materials to: + + Violate the law or others’ rights, including to: +Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: +Violence or terrorism +Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material +Human trafficking, exploitation, and sexual violence +The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. +Sexual solicitation +Any other criminal activity + +Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals + +Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services + +Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices + +Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws + +Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using Research Materials + +Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system + +2. 
Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following: + +Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State + +Guns and illegal weapons (including weapon development) + +Illegal drugs and regulated/controlled substances +Operation of critical infrastructure, transportation technologies, or heavy machinery + +Self-harm or harm to others, including suicide, cutting, and eating disorders +Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual + +3. Intentionally deceive or mislead others, including use of Research Materials related to the following: + + Generating, promoting, or furthering fraud or the creation or promotion of disinformation + Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content + +Generating, promoting, or further distributing spam + + Impersonating another individual without consent, authorization, or legal right + +Representing that outputs of research materials or outputs from technology using Research Materials are human-generated + +Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement + +4. Fail to appropriately disclose to end users any known dangers of your Research Materials. diff --git a/vggt/__pycache__/utils.cpython-39.pyc b/vggt/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79142680b02839bf44d8decb5d03db119c6fb8f7 Binary files /dev/null and b/vggt/__pycache__/utils.cpython-39.pyc differ diff --git a/vggt/docs/package.md b/vggt/docs/package.md new file mode 100644 index 0000000000000000000000000000000000000000..356df89b613f9b48dd47d8b993bf792715237a6b --- /dev/null +++ b/vggt/docs/package.md @@ -0,0 +1,45 @@ +# Alternative Installation Methods + +This document explains how to install VGGT as a package using different package managers. + +## Prerequisites + +Before installing VGGT as a package, you need to install PyTorch and torchvision. We don't list these as dependencies to avoid CUDA version mismatches. Install them first; for example: + +```bash +# install pytorch 2.3.1 with cuda 12.1 +pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121 +``` + +## Installation Options + +### Install with pip + +The simplest way to install VGGT is using pip: + +```bash +pip install -e . +``` + +### Install and run with pixi + +[Pixi](https://pixi.sh) is a package management tool for creating reproducible environments. + +1. First, [download and install pixi](https://pixi.sh/latest/get_started/) +2. Then run: + +```bash +pixi run -e demo python demo_gradio.py +``` + +### Install and run with uv + +[uv](https://docs.astral.sh/uv/) is a fast Python package installer and resolver. + +1. First, [install uv](https://docs.astral.sh/uv/getting-started/installation/) +2. 
Then run: + +```bash +uv run --extra demo demo_gradio.py +``` + diff --git a/vggt/pyproject.toml b/vggt/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..81d4f1de65b9218aaf9f2c8c6a3596c9cfd19d48 --- /dev/null +++ b/vggt/pyproject.toml @@ -0,0 +1,52 @@ +[project] +authors = [{name = "Jianyuan Wang", email = "jianyuan@robots.ox.ac.uk"}] +dependencies = [ + "numpy<2", + "Pillow", + "huggingface_hub", + "einops", + "safetensors", + "opencv-python", +] +name = "vggt" +requires-python = ">= 3.10" +version = "0.0.1" + +[project.optional-dependencies] +demo = [ + "gradio==5.17.1", + "viser==0.2.23", + "tqdm", + "hydra-core", + "omegaconf", + "opencv-python", + "scipy", + "onnxruntime", + "requests", + "trimesh", + "matplotlib", +] + +# Using setuptools as the build backend +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +# setuptools configuration +[tool.setuptools.packages.find] +where = ["."] +include = ["vggt*"] + +# Pixi configuration +[tool.pixi.workspace] +channels = ["conda-forge"] +platforms = ["linux-64"] + +[tool.pixi.pypi-dependencies] +vggt = { path = ".", editable = true } + +[tool.pixi.environments] +default = { solve-group = "default" } +demo = { features = ["demo"], solve-group = "default" } + +[tool.pixi.tasks] diff --git a/vggt/training/README.md b/vggt/training/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e371f3df191058e0f58c5a17572fec58d9a74181 --- /dev/null +++ b/vggt/training/README.md @@ -0,0 +1,132 @@ +# Training + +This is a re-implementation of our framework for training VGGT. This document shows how to set up the environment and run VGGT training. I have aimed to faithfully reproduce the original training framework, but please open an issue if anything looks off. + +## 1. Prerequisites + +Before you begin, ensure you have completed the following steps: + +1. **Install VGGT as a package:** + ```bash + pip install -e . + ``` + +2. **Prepare the dataset and annotations:** + - Download the Co3D dataset from the [official repository](https://github.com/facebookresearch/co3d). + - Download the required annotation files from [Hugging Face](https://huggingface.co/datasets/JianyuanWang/co3d_anno/tree/main). + +## 2. Configuration + +After downloading the dataset and annotations, configure the paths in `training/config/default.yaml`. + +### Required Path Configuration + +1. Open `training/config/default.yaml` +2. Update the following paths with your absolute directory paths: + - `CO3D_DIR`: Path to your Co3D dataset + - `CO3D_ANNOTATION_DIR`: Path to your Co3D annotation files + - `resume_checkpoint_path`: Path to your pre-trained VGGT checkpoint + +### Configuration Example + +```yaml +data: + train: + dataset: + dataset_configs: + - _target_: data.datasets.co3d.Co3dDataset + split: train + CO3D_DIR: /YOUR/PATH/TO/CO3D + CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION +# ... same for val ... + +checkpoint: + resume_checkpoint_path: /YOUR/PATH/TO/CKPT +``` + +## 3. Fine-tuning on Co3D + +To fine-tune the provided pre-trained model on the Co3D dataset, run the following command. This example uses 4 GPUs with PyTorch Distributed Data Parallel (DDP): + +```bash +torchrun --nproc_per_node=4 launch.py +``` + +The default configuration in `training/config/default.yaml` is set up for fine-tuning. It automatically resumes from a checkpoint and freezes the model's `aggregator` module during training. + +## 4. 
Training on Multiple Datasets + +The dataloader supports multiple datasets naturally. For example, if you have downloaded VKitti using `preprocess/vkitti.sh`, you can train on Co3D+VKitti by configuring: + +```yaml +data: + train: + dataset: + _target_: data.composed_dataset.ComposedDataset + dataset_configs: + - _target_: data.datasets.co3d.Co3dDataset + split: train + CO3D_DIR: /YOUR/PATH/TO/CO3D + CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION + len_train: 100000 + - _target_: data.datasets.vkitti.VKittiDataset + split: train + VKitti_DIR: /YOUR/PATH/TO/VKitti + len_train: 100000 + expand_ratio: 8 +``` + +The ratio of different datasets can be controlled by setting `len_train`. For example, Co3D with `len_train: 10000` and VKitti with `len_train: 2000` will result in Co3D being sampled five times more frequently than VKitti. + +## 5. Common Questions + +### Memory Management + +If you encounter out-of-memory (OOM) errors on your GPU, consider adjusting the following parameters in `training/config/default.yaml`: + +- `max_img_per_gpu`: Reduce this value to decrease the batch size per GPU +- `accum_steps`: Sets the number of gradient accumulation steps (default is 2). This feature splits batches into smaller chunks to save memory, though it may slightly increase training time. Note that gradient accumulation was not used for the original VGGT model. + +### Learning Rate Tuning + +The main hyperparameter to be careful about is learning rate. Note that learning rate depends on the effective batch size, which is `batch_size_per_gpu × num_gpus`. Therefore, I highly recommend trying several learning rates based on your training setup. Generally, trying values like `5e-6`, `1e-5`, `5e-5`, `1e-4`, `5e-4` should be sufficient. + +### Tracking Head + +The tracking head can slightly improve accuracy but is not necessary. For general cases, especially when GPU resources are limited, we suggest fine-tuning the pre-trained model only with camera and depth heads, which is the setting in `default.yaml`. This will provide good enough results. + +### Dataloader Validation + +To check if your dataloader is working correctly, the best approach is to visualize its output. You can save the 3D world points as follows and then visually inspect the PLY files: + +```python +def save_ply(points, colors, filename): + import open3d as o3d + if torch.is_tensor(points): + points_visual = points.reshape(-1, 3).cpu().numpy() + else: + points_visual = points.reshape(-1, 3) + if torch.is_tensor(colors): + points_visual_rgb = colors.reshape(-1, 3).cpu().numpy() + else: + points_visual_rgb = colors.reshape(-1, 3) + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points_visual.astype(np.float64)) + pcd.colors = o3d.utility.Vector3dVector(points_visual_rgb.astype(np.float64)) + o3d.io.write_point_cloud(filename, pcd, write_ascii=True) + +# Usage example +save_ply( + batch["world_points"][0].reshape(-1, 3), + batch["images"][0].permute(0, 2, 3, 1).reshape(-1, 3), + "debug.ply" +) +``` + +### Handling Unordered Sequences + +For unordered sequences, you can check how we compute the ranking (similarity) between one frame and all other frames, as discussed in [Issue #82](https://github.com/facebookresearch/vggt/issues/82). + +### Expected Coordinate System + +Camera poses are expected to follow the OpenCV `camera-from-world` convention. Depth maps should be aligned with their corresponding camera poses. 
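+For concreteness, here is a minimal NumPy sketch of this convention. The extrinsic and intrinsic values below are made up purely for illustration; only the `(3, 4)` `[R | t]` camera-from-world layout and the unprojection steps reflect the convention described above (compare `depth_to_world_coords_points` in `training/data/dataset_util.py`).
+
+```python
+import numpy as np
+
+# Illustrative (not real) camera: extrinsic is a (3, 4) camera-from-world [R | t] matrix.
+extrinsic = np.array([
+    [1.0, 0.0, 0.0, 0.1],
+    [0.0, 1.0, 0.0, 0.0],
+    [0.0, 0.0, 1.0, 2.0],
+])
+intrinsic = np.array([
+    [500.0,   0.0, 320.0],
+    [  0.0, 500.0, 240.0],
+    [  0.0,   0.0,   1.0],
+])
+R, t = extrinsic[:, :3], extrinsic[:, 3]
+
+# camera-from-world: a world point maps into the camera frame as X_cam = R @ X_world + t
+X_world = np.array([0.5, -0.2, 3.0])
+X_cam = R @ X_world + t
+
+# projecting with the intrinsics gives pixel coordinates (u, v); the depth value is X_cam[2]
+u_h = intrinsic @ X_cam
+u, v, d = u_h[0] / u_h[2], u_h[1] / u_h[2], X_cam[2]
+
+# the camera center in world coordinates is C = -R.T @ t
+C_world = -R.T @ t
+
+# back-projecting the depth at (u, v) and applying the inverse (world-from-camera)
+# transform recovers the world point, mirroring how the dataloader builds world points
+fx, fy, cx, cy = intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]
+X_cam_rec = np.array([(u - cx) * d / fx, (v - cy) * d / fy, d])
+X_world_rec = R.T @ (X_cam_rec - t)
+assert np.allclose(X_world_rec, X_world)
+```
+
+If your own data follows a different convention (e.g. camera-to-world or OpenGL axes), convert the poses to this camera-from-world OpenCV form before feeding them to the dataloader.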
diff --git a/vggt/training/__init__.py b/vggt/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggt/training/config/default.yaml b/vggt/training/config/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8bec7a7361c36065ff4dc34518f7eac53026f303 --- /dev/null +++ b/vggt/training/config/default.yaml @@ -0,0 +1,183 @@ +defaults: + - default_dataset.yaml + +exp_name: exp001 +img_size: 518 +num_workers: 8 +seed_value: 42 +accum_steps: 2 # We did not use gradient accumulation in our training, while if you suffer from OOM, you can try to use it. +patch_size: 14 +val_epoch_freq: 5 +max_img_per_gpu: 48 + +limit_train_batches: 800 +limit_val_batches: 400 + + +data: + # The code for data still looks too complicated. I should refactor this again (do I have time?...) + train: + _target_: data.dynamic_dataloader.DynamicTorchDataset + num_workers: ${num_workers} + max_img_per_gpu: ${max_img_per_gpu} + common_config: + img_size: ${img_size} + patch_size: ${patch_size} + debug: False + repeat_batch: False + dataset: + _target_: data.composed_dataset.ComposedDataset + dataset_configs: + - _target_: data.datasets.co3d.Co3dDataset + split: train + CO3D_DIR: /YOUR/PATH/TO/CO3D + CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION + val: + _target_: data.dynamic_dataloader.DynamicTorchDataset + num_workers: ${num_workers} + max_img_per_gpu: ${max_img_per_gpu} + common_config: + img_size: ${img_size} + patch_size: ${patch_size} + debug: False + dataset: + _target_: data.composed_dataset.ComposedDataset + dataset_configs: + - _target_: data.datasets.co3d.Co3dDataset + split: test + CO3D_DIR: /YOUR/PATH/TO/CO3D + CO3D_ANNOTATION_DIR: /YOUR/PATH/TO/CO3D_ANNOTATION + + +logging: + log_dir: logs + log_visuals: False + log_freq: 1 + log_level_primary: DEBUG + log_level_secondary: WARNING + all_ranks: False + tensorboard_writer: + _target_: train_utils.tb_writer.TensorBoardLogger + path: ${logging.log_dir}/tensorboard + scalar_keys_to_log: + train: + keys_to_log: + - loss_objective + - loss_camera + - loss_T + - loss_R + - loss_FL + - loss_conf_depth + - loss_reg_depth + - loss_grad_depth + val: + keys_to_log: + - loss_objective + - loss_camera + - loss_T + - loss_R + - loss_FL + - loss_conf_depth + - loss_reg_depth + - loss_grad_depth + + + +checkpoint: + save_dir: logs/${exp_name}/ckpts + save_freq: 5 + resume_checkpoint_path: /YOUR/PATH/TO/CKPT + strict: False + + +loss: + _target_: loss.MultitaskLoss + camera: + weight: 5.0 + loss_type: "l1" # The paper uses smooth l1 loss, but we found l1 loss is more stable than smooth l1 and l2 loss. 
+ depth: + weight: 1.0 + gradient_loss_fn: "grad" + valid_range: 0.98 + point: null + # If you want to enable point, use the following config + # point: + # weight: 1.0 + # gradient_loss_fn: "normal" + # valid_range: 0.98 + track: null + + + + +optim: + param_group_modifiers: False + + optimizer: + _target_: torch.optim.AdamW + lr: 5e-5 + weight_decay: 0.05 + + frozen_module_names: + - "*aggregator*" # example, freeze the aggregator + + amp: + enabled: True + amp_dtype: bfloat16 + gradient_clip: + _target_: train_utils.gradient_clip.GradientClipper + configs: + - module_name: ["aggregator"] + max_norm: 1.0 # feel free to reduce this if you see instabilities + norm_type: 2 + - module_name: ["depth"] + max_norm: 1.0 # feel free to reduce this if you see instabilities + norm_type: 2 + - module_name: ["camera"] + max_norm: 1.0 # feel free to reduce this if you see instabilities + norm_type: 2 + options: + lr: + - scheduler: + _target_: fvcore.common.param_scheduler.CompositeParamScheduler + schedulers: + - _target_: fvcore.common.param_scheduler.LinearParamScheduler + start_value: 1e-8 + end_value: 5e-5 + - _target_: fvcore.common.param_scheduler.CosineParamScheduler + start_value: 5e-5 + end_value: 1e-8 + lengths: [0.05, 0.95] + interval_scaling: ['rescaled', 'rescaled'] + weight_decay: + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.05 + + + + +max_epochs: 20 + +model: + _target_: vggt.models.vggt.VGGT + enable_camera: True + enable_depth: True + enable_point: False + enable_track: False + + +distributed: + # check https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html for options + backend: nccl + comms_dtype: None + find_unused_parameters: False + timeout_mins: 30 + gradient_as_bucket_view: True # Less memory used + bucket_cap_mb: 25 + broadcast_buffers: True + +cuda: + cudnn_deterministic: False + cudnn_benchmark: False + allow_tf32: True diff --git a/vggt/training/config/default_dataset.yaml b/vggt/training/config/default_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0390dfa369021caf7fdf6da35a49a7f0d0f2a2c3 --- /dev/null +++ b/vggt/training/config/default_dataset.yaml @@ -0,0 +1,80 @@ +# Template for the dataset config +data: + # The code still looks too complicated. I should refactor this again (do I have time?...) + train: + _target_: data.dynamic_dataloader.DynamicTorchDataset + num_workers: 8 + max_img_per_gpu: 48 + # Shuffling in PyTorch DataLoader can sometimes copy large dicts and exceed CPU memory + # (see: https://github.com/pytorch/pytorch/issues/13246). + # To avoid this, set shuffle=False and enable common_config.inside_random=True instead. 
+ shuffle: True + pin_memory: False + common_config: # common config for evaluation + fix_img_num: -1 # -1 means do not fix the number of images + fix_aspect_ratio: 1.0 + load_track: False + track_num: 1024 + training: True + inside_random: True + img_size: 224 + patch_size: 14 + rescale: True + rescale_aug: True + landscape_check: False + debug: False + get_nearby: True + load_depth: True + img_nums: [2, 24] + max_img_per_gpu: 48 + allow_duplicate_img: True + repeat_batch: False + + augs: + cojitter: True + cojitter_ratio: 0.3 + scales: [0.8, 1.2] + aspects: [0.33, 1.0] + color_jitter: + brightness: 0.5 + contrast: 0.5 + saturation: 0.5 + hue: 0.1 + p: 0.9 + gray_scale: True + gau_blur: False + val: + _target_: data.dynamic_dataloader.DynamicTorchDataset + num_workers: 8 + max_img_per_gpu: 48 + # Shuffling in PyTorch DataLoader can sometimes copy large dicts and exceed CPU memory + # (see: https://github.com/pytorch/pytorch/issues/13246). + # To avoid this, set shuffle=False and enable common_config.inside_random=True instead. + shuffle: True + pin_memory: False + common_config: # common config for evaluation + fix_img_num: -1 # -1 means do not fix the number of images + fix_aspect_ratio: 1.0 + load_track: False + track_num: 1024 + training: False + inside_random: True + img_size: 224 + patch_size: 14 + rescale: True + rescale_aug: False + landscape_check: False + debug: False + get_nearby: True + load_depth: True + img_nums: [2, 12] + allow_duplicate_img: True + + augs: + cojitter: False + cojitter_ratio: 0.5 + scales: null + aspects: [1.0, 1.0] + color_jitter: null + gray_scale: False + gau_blur: False \ No newline at end of file diff --git a/vggt/training/data/__init__.py b/vggt/training/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggt/training/data/augmentation.py b/vggt/training/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..6eef99cc65c69e349eb2982bc90fbf98808968bc --- /dev/null +++ b/vggt/training/data/augmentation.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Dict +from torchvision import transforms + + +def get_image_augmentation( + color_jitter: Optional[Dict[str, float]] = None, + gray_scale: bool = True, + gau_blur: bool = False +) -> Optional[transforms.Compose]: + """Create a composition of image augmentations. 
+ + Args: + color_jitter: Dictionary containing color jitter parameters: + - brightness: float (default: 0.5) + - contrast: float (default: 0.5) + - saturation: float (default: 0.5) + - hue: float (default: 0.1) + - p: probability of applying (default: 0.9) + If None, uses default values + gray_scale: Whether to apply random grayscale (default: True) + gau_blur: Whether to apply gaussian blur (default: False) + + Returns: + A Compose object of transforms or None if no transforms are added + """ + transform_list = [] + default_jitter = { + "brightness": 0.5, + "contrast": 0.5, + "saturation": 0.5, + "hue": 0.1, + "p": 0.9 + } + + # Handle color jitter + if color_jitter is not None: + # Merge with defaults for missing keys + effective_jitter = {**default_jitter, **color_jitter} + else: + effective_jitter = default_jitter + + transform_list.append( + transforms.RandomApply( + [ + transforms.ColorJitter( + brightness=effective_jitter["brightness"], + contrast=effective_jitter["contrast"], + saturation=effective_jitter["saturation"], + hue=effective_jitter["hue"], + ) + ], + p=effective_jitter["p"], + ) + ) + + if gray_scale: + transform_list.append(transforms.RandomGrayscale(p=0.05)) + + if gau_blur: + transform_list.append( + transforms.RandomApply( + [transforms.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05 + ) + ) + + return transforms.Compose(transform_list) if transform_list else None diff --git a/vggt/training/data/base_dataset.py b/vggt/training/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..28b0d3f3195eb8a122504a56c10b94766c8a75d2 --- /dev/null +++ b/vggt/training/data/base_dataset.py @@ -0,0 +1,303 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from PIL import Image, ImageFile + +from torch.utils.data import Dataset +from .dataset_util import * + +Image.MAX_IMAGE_PIXELS = None +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class BaseDataset(Dataset): + """ + Base dataset class for VGGT and VGGSfM training. + + This abstract class handles common operations like image resizing, + augmentation, and coordinate transformations. Concrete dataset + implementations should inherit from this class. + + Attributes: + img_size: Target image size (typically the width) + patch_size: Size of patches for vit + augs.scales: Scale range for data augmentation [min, max] + rescale: Whether to rescale images + rescale_aug: Whether to apply augmentation during rescaling + landscape_check: Whether to handle landscape vs portrait orientation + """ + def __init__( + self, + common_conf, + ): + """ + Initialize the base dataset with common configuration. + + Args: + common_conf: Configuration object with the following properties, shared by all datasets: + - img_size: Default is 518 + - patch_size: Default is 14 + - augs.scales: Default is [0.8, 1.2] + - rescale: Default is True + - rescale_aug: Default is True + - landscape_check: Default is True + """ + super().__init__() + self.img_size = common_conf.img_size + self.patch_size = common_conf.patch_size + self.aug_scale = common_conf.augs.scales + self.rescale = common_conf.rescale + self.rescale_aug = common_conf.rescale_aug + self.landscape_check = common_conf.landscape_check + + def __len__(self): + return self.len_train + + def __getitem__(self, idx_N): + """ + Get an item from the dataset. 
+ + Args: + idx_N: Tuple containing (seq_index, img_per_seq, aspect_ratio) + + Returns: + Dataset item as returned by get_data() + """ + seq_index, img_per_seq, aspect_ratio = idx_N + return self.get_data( + seq_index=seq_index, img_per_seq=img_per_seq, aspect_ratio=aspect_ratio + ) + + def get_data(self, seq_index=None, seq_name=None, ids=None, aspect_ratio=1.0): + """ + Abstract method to retrieve data for a given sequence. + + Args: + seq_index (int, optional): Index of the sequence + seq_name (str, optional): Name of the sequence + ids (list, optional): List of frame IDs + aspect_ratio (float, optional): Target aspect ratio. + + Returns: + Dataset-specific data + + Raises: + NotImplementedError: This method must be implemented by subclasses + """ + raise NotImplementedError( + "This is an abstract method and should be implemented in the subclass, i.e., each dataset should implement its own get_data method." + ) + + def get_target_shape(self, aspect_ratio): + """ + Calculate the target shape based on the given aspect ratio. + + Args: + aspect_ratio: Target aspect ratio + + Returns: + numpy.ndarray: Target image shape [height, width] + """ + short_size = int(self.img_size * aspect_ratio) + small_size = self.patch_size + + # ensure the input shape is friendly to vision transformer + if short_size % small_size != 0: + short_size = (short_size // small_size) * small_size + + image_shape = np.array([short_size, self.img_size]) + return image_shape + + def process_one_image( + self, + image, + depth_map, + extri_opencv, + intri_opencv, + original_size, + target_image_shape, + track=None, + filepath=None, + safe_bound=4, + ): + """ + Process a single image and its associated data. + + This method handles image transformations, depth processing, and coordinate conversions. + + Args: + image (numpy.ndarray): Input image array + depth_map (numpy.ndarray): Depth map array + extri_opencv (numpy.ndarray): Extrinsic camera matrix (OpenCV convention) + intri_opencv (numpy.ndarray): Intrinsic camera matrix (OpenCV convention) + original_size (numpy.ndarray): Original image size [height, width] + target_image_shape (numpy.ndarray): Target image shape after processing + track (numpy.ndarray, optional): Optional tracking information. Defaults to None. + filepath (str, optional): Optional file path for debugging. Defaults to None. + safe_bound (int, optional): Safety margin for cropping operations. Defaults to 4. 
+ + Returns: + tuple: ( + image (numpy.ndarray): Processed image, + depth_map (numpy.ndarray): Processed depth map, + extri_opencv (numpy.ndarray): Updated extrinsic matrix, + intri_opencv (numpy.ndarray): Updated intrinsic matrix, + world_coords_points (numpy.ndarray): 3D points in world coordinates, + cam_coords_points (numpy.ndarray): 3D points in camera coordinates, + point_mask (numpy.ndarray): Boolean mask of valid points, + track (numpy.ndarray, optional): Updated tracking information + ) + """ + # Make copies to avoid in-place operations affecting original data + image = np.copy(image) + depth_map = np.copy(depth_map) + extri_opencv = np.copy(extri_opencv) + intri_opencv = np.copy(intri_opencv) + if track is not None: + track = np.copy(track) + + # Apply random scale augmentation during training if enabled + if self.training and self.aug_scale: + random_h_scale, random_w_scale = np.random.uniform( + self.aug_scale[0], self.aug_scale[1], 2 + ) + # Avoid random padding by capping at 1.0 + random_h_scale = min(random_h_scale, 1.0) + random_w_scale = min(random_w_scale, 1.0) + aug_size = original_size * np.array([random_h_scale, random_w_scale]) + aug_size = aug_size.astype(np.int32) + else: + aug_size = original_size + + # Move principal point to the image center and crop if necessary + image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp( + image, depth_map, intri_opencv, aug_size, track=track, filepath=filepath, + ) + + original_size = np.array(image.shape[:2]) # update original_size + target_shape = target_image_shape + + # Handle landscape vs. portrait orientation + rotate_to_portrait = False + if self.landscape_check: + # Switch between landscape and portrait if necessary + if original_size[0] > 1.25 * original_size[1]: + if (target_image_shape[0] != target_image_shape[1]) and (np.random.rand() > 0.5): + target_shape = np.array([target_image_shape[1], target_image_shape[0]]) + rotate_to_portrait = True + + # Resize images and update intrinsics + if self.rescale: + image, depth_map, intri_opencv, track = resize_image_depth_and_intrinsic( + image, depth_map, intri_opencv, target_shape, original_size, track=track, + safe_bound=safe_bound, + rescale_aug=self.rescale_aug + ) + else: + print("Not rescaling the images") + + # Ensure final crop to target shape + image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp( + image, depth_map, intri_opencv, target_shape, track=track, filepath=filepath, strict=True, + ) + + # Apply 90-degree rotation if needed + if rotate_to_portrait: + assert self.landscape_check + clockwise = np.random.rand() > 0.5 + image, depth_map, extri_opencv, intri_opencv, track = rotate_90_degrees( + image, + depth_map, + extri_opencv, + intri_opencv, + clockwise=clockwise, + track=track, + ) + + # Convert depth to world and camera coordinates + world_coords_points, cam_coords_points, point_mask = ( + depth_to_world_coords_points(depth_map, extri_opencv, intri_opencv) + ) + + return ( + image, + depth_map, + extri_opencv, + intri_opencv, + world_coords_points, + cam_coords_points, + point_mask, + track, + ) + + def get_nearby_ids(self, ids, full_seq_num, expand_ratio=None, expand_range=None): + """ + TODO: add the function to sample the ids by pose similarity ranking. + + Sample a set of IDs from a sequence close to a given start index. + + You can specify the range either as a ratio of the number of input IDs + or as a fixed integer window. + + + Args: + ids (list): Initial list of IDs. 
The first element is used as the anchor. + full_seq_num (int): Total number of items in the full sequence. + expand_ratio (float, optional): Factor by which the number of IDs expands + around the start index. Default is 2.0 if neither expand_ratio nor + expand_range is provided. + expand_range (int, optional): Fixed number of items to expand around the + start index. If provided, expand_ratio is ignored. + + Returns: + numpy.ndarray: Array of sampled IDs, with the first element being the + original start index. + + Examples: + # Using expand_ratio (default behavior) + # If ids=[100,101,102] and full_seq_num=200, with expand_ratio=2.0, + # expand_range = int(3 * 2.0) = 6, so IDs sampled from [94...106] (if boundaries allow). + + # Using expand_range directly + # If ids=[100,101,102] and full_seq_num=200, with expand_range=10, + # IDs are sampled from [90...110] (if boundaries allow). + + Raises: + ValueError: If no IDs are provided. + """ + if len(ids) == 0: + raise ValueError("No IDs provided.") + + if expand_range is None and expand_ratio is None: + expand_ratio = 2.0 # Default behavior + + total_ids = len(ids) + start_idx = ids[0] + + # Determine the actual expand_range + if expand_range is None: + # Use ratio to determine range + expand_range = int(total_ids * expand_ratio) + + # Calculate valid boundaries + low_bound = max(0, start_idx - expand_range) + high_bound = min(full_seq_num, start_idx + expand_range) + + # Create the valid range of indices + valid_range = np.arange(low_bound, high_bound) + + # Sample 'total_ids - 1' items, because we already have the start_idx + sampled_ids = np.random.choice( + valid_range, + size=(total_ids - 1), + replace=True, # we accept the situation that some sampled ids are the same + ) + + # Insert the start_idx at the beginning + result_ids = np.insert(sampled_ids, 0, start_idx) + + return result_ids diff --git a/vggt/training/data/composed_dataset.py b/vggt/training/data/composed_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c327757f90d3178c13b90d8fdb65caf56a382f55 --- /dev/null +++ b/vggt/training/data/composed_dataset.py @@ -0,0 +1,261 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC + +from hydra.utils import instantiate +import torch +import random +import numpy as np +from torch.utils.data import Dataset +from torch.utils.data import ConcatDataset +import bisect +from .dataset_util import * +from .track_util import * +from .augmentation import get_image_augmentation + + +class ComposedDataset(Dataset, ABC): + """ + Composes multiple base datasets and applies common configurations. + + This dataset provides a flexible way to combine multiple base datasets while + applying shared augmentations, track generation, and other processing steps. + It handles image normalization, tensor conversion, and other preparations + needed for training computer vision models with sequences of images. + """ + def __init__(self, dataset_configs: dict, common_config: dict, **kwargs): + """ + Initializes the ComposedDataset. + + Args: + dataset_configs (dict): List of Hydra configurations for base datasets. + common_config (dict): Shared configurations (augs, tracks, mode, etc.). + **kwargs: Additional arguments (unused). 
+ """ + base_dataset_list = [] + + # Instantiate each base dataset with common configuration + for baseset_dict in dataset_configs: + baseset = instantiate(baseset_dict, common_conf=common_config) + base_dataset_list.append(baseset) + + # Use custom concatenation class that supports tuple indexing + self.base_dataset = TupleConcatDataset(base_dataset_list, common_config) + + # --- Augmentation Settings --- + # Controls whether to apply identical color jittering across all frames in a sequence + self.cojitter = common_config.augs.cojitter + # Probability of using shared jitter vs. frame-specific jitter + self.cojitter_ratio = common_config.augs.cojitter_ratio + # Initialize image augmentations (color jitter, grayscale, gaussian blur) + self.image_aug = get_image_augmentation( + color_jitter=common_config.augs.color_jitter, + gray_scale=common_config.augs.gray_scale, + gau_blur=common_config.augs.gau_blur, + ) + + # --- Optional Fixed Settings (useful for debugging) --- + # Force each sequence to have exactly this many images (if > 0) + self.fixed_num_images = common_config.fix_img_num + # Force a specific aspect ratio for all images + self.fixed_aspect_ratio = common_config.fix_aspect_ratio + + # --- Track Settings --- + # Whether to include point tracks in the output + self.load_track = common_config.load_track + # Number of point tracks to include per sequence + self.track_num = common_config.track_num + + # --- Mode Settings --- + # Whether the dataset is being used for training (affects augmentations) + self.training = common_config.training + self.common_config = common_config + + self.total_samples = len(self.base_dataset) + + def __len__(self): + """Returns the total number of sequences in the dataset.""" + return self.total_samples + + + def __getitem__(self, idx_tuple): + """ + Retrieves a data sample (sequence) from the dataset. + + Loads raw data, converts to PyTorch tensors, applies augmentations, + and prepares tracks if enabled. + + Args: + idx_tuple (tuple): a tuple of (seq_idx, num_images, aspect_ratio) + + Returns: + dict: A dictionary containing the sequence data (images, poses, tracks, etc.). 
+ """ + # If fixed settings are provided, override the tuple values + if self.fixed_num_images > 0: + seq_idx = idx_tuple[0] if isinstance(idx_tuple, tuple) else idx_tuple + idx_tuple = (seq_idx, self.fixed_num_images, self.fixed_aspect_ratio) + + # Retrieve the raw data batch from the appropriate base dataset + batch = self.base_dataset[idx_tuple] + seq_name = batch["seq_name"] + + # --- Data Conversion and Preparation --- + # Convert numpy arrays to tensors + images = torch.from_numpy(np.stack(batch["images"]).astype(np.float32)).contiguous() + # Normalize images from [0, 255] to [0, 1] + images = images.permute(0,3,1,2).to(torch.get_default_dtype()).div(255) + + # Convert other data to tensors with appropriate types + depths = torch.from_numpy(np.stack(batch["depths"]).astype(np.float32)) + extrinsics = torch.from_numpy(np.stack(batch["extrinsics"]).astype(np.float32)) + intrinsics = torch.from_numpy(np.stack(batch["intrinsics"]).astype(np.float32)) + cam_points = torch.from_numpy(np.stack(batch["cam_points"]).astype(np.float32)) + world_points = torch.from_numpy(np.stack(batch["world_points"]).astype(np.float32)) + point_masks = torch.from_numpy(np.stack(batch["point_masks"])) # Mask indicating valid depths / world points / cam points per frame + ids = torch.from_numpy(batch["ids"]) # Frame indices sampled from the original sequence + + + # --- Apply Color Augmentation (training mode only) --- + if self.training and self.image_aug is not None: + if self.cojitter and random.random() > self.cojitter_ratio: + # Apply the same color jittering transformation to all frames + images = self.image_aug(images) + else: + # Apply different color jittering to each frame individually + for aug_img_idx in range(len(images)): + images[aug_img_idx] = self.image_aug(images[aug_img_idx]) + + + # --- Prepare Final Sample Dictionary --- + sample = { + "seq_name": seq_name, + "ids": ids, + "images": images, + "depths": depths, + "extrinsics": extrinsics, + "intrinsics": intrinsics, + "cam_points": cam_points, + "world_points": world_points, + "point_masks": point_masks, + } + + # --- Track Processing (if enabled) --- + if self.load_track: + if batch["tracks"] is not None: + # Use pre-computed tracks from the dataset + tracks = torch.from_numpy(np.stack(batch["tracks"]).astype(np.float32)) + track_vis_mask = torch.from_numpy(np.stack(batch["track_masks"]).astype(bool)) + + # Sample a subset of tracks randomly + valid_indices = torch.where(track_vis_mask[0])[0] + if len(valid_indices) >= self.track_num: + # If we have enough tracks, sample without replacement + sampled_indices = valid_indices[torch.randperm(len(valid_indices))][:self.track_num] + else: + # If not enough tracks, sample with replacement (allow duplicates) + sampled_indices = valid_indices[torch.randint(0, len(valid_indices), + (self.track_num,), + dtype=torch.int64, + device=valid_indices.device)] + + # Extract the sampled tracks and their masks + tracks = tracks[:, sampled_indices, :] + track_vis_mask = track_vis_mask[:, sampled_indices] + track_positive_mask = torch.ones(track_vis_mask.shape[1]).bool() + + else: + # Generate tracks on-the-fly using depth information + # This creates synthetic tracks based on the 3D information available + tracks, track_vis_mask, track_positive_mask = build_tracks_by_depth( + extrinsics, intrinsics, world_points, depths, point_masks, images, + target_track_num=self.track_num, seq_name=seq_name + ) + + # Add track information to the sample dictionary + sample["tracks"] = tracks + sample["track_vis_mask"] = 
track_vis_mask + sample["track_positive_mask"] = track_positive_mask + + return sample + + +class TupleConcatDataset(ConcatDataset): + """ + A custom ConcatDataset that supports indexing with a tuple. + + Standard PyTorch ConcatDataset only accepts an integer index. This class extends + that functionality to allow passing a tuple like (sample_idx, num_images, aspect_ratio), + where the first element is used to determine which sample to fetch, and the full + tuple is passed down to the selected dataset's __getitem__ method. + + It also supports an option to randomly sample across all datasets, ignoring the + provided index. This is useful during training when shuffling the entire dataset + might cause memory issues due to duplicating dictionaries. If doing this, you can + set pytorch's dataloader shuffle to False. + """ + def __init__(self, datasets, common_config): + """ + Initialize the TupleConcatDataset. + + Args: + datasets (iterable): An iterable of PyTorch Dataset objects to concatenate. + common_config (dict): Common configuration dict, used to check for random sampling. + """ + super().__init__(datasets) + # If True, ignores the input index and samples randomly across all datasets + # This provides an alternative to dataloader shuffling for large datasets + self.inside_random = common_config.inside_random + + def __getitem__(self, idx): + """ + Retrieves an item using either an integer index or a tuple index. + + Args: + idx (int or tuple): The index. If tuple, the first element is the sequence + index across the concatenated datasets, and the rest are + passed down. If int, it's treated as the sequence index. + + Returns: + The item returned by the underlying dataset's __getitem__ method. + + Raises: + ValueError: If the index is out of range or the tuple doesn't have exactly 3 elements. + """ + idx_tuple = None + if isinstance(idx, tuple): + idx_tuple = idx + idx = idx_tuple[0] # Extract the sequence index + + # Override index with random value if inside_random is enabled + if self.inside_random: + total_len = self.cumulative_sizes[-1] + idx = random.randint(0, total_len - 1) + + # Handle negative indices + if idx < 0: + if -idx > len(self): + raise ValueError( + "absolute value of index should not exceed dataset length" + ) + idx = len(self) + idx + + # Find which dataset the index belongs to + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + + # Create the tuple to pass to the underlying dataset + if len(idx_tuple) == 3: + idx_tuple = (sample_idx,) + idx_tuple[1:] + else: + raise ValueError("Tuple index must have exactly three elements") + + # Pass the modified tuple to the appropriate dataset + return self.datasets[dataset_idx][idx_tuple] diff --git a/vggt/training/data/dataset_util.py b/vggt/training/data/dataset_util.py new file mode 100644 index 0000000000000000000000000000000000000000..542af78fdeb138a2b53862f7c6c48c26ffe92b42 --- /dev/null +++ b/vggt/training/data/dataset_util.py @@ -0,0 +1,711 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 +import math +import numpy as np +from PIL import Image +import PIL +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + +from vggt.utils.geometry import closed_form_inverse_se3 + + + +##################################################################################################################### +def crop_image_depth_and_intrinsic_by_pp( + image, depth_map, intrinsic, target_shape, track=None, filepath=None, strict=False +): + """ + TODO: some names of width and height seem not consistent. Need to check. + + + Crops the given image and depth map around the camera's principal point, as defined by `intrinsic`. + Specifically: + - Ensures that the crop is centered on (cx, cy). + - Optionally pads the image (and depth map) if `strict=True` and the result is smaller than `target_shape`. + - Shifts the camera intrinsic matrix (and `track` if provided) accordingly. + + Args: + image (np.ndarray): + Input image array of shape (H, W, 3). + depth_map (np.ndarray or None): + Depth map array of shape (H, W), or None if not available. + intrinsic (np.ndarray): + Camera intrinsic matrix (3x3). The principal point is assumed to be at (intrinsic[1,2], intrinsic[0,2]). + target_shape (tuple[int, int]): + Desired output shape. + track (np.ndarray or None): + Optional array of shape (N, 2). Interpreted as (x, y) pixel coordinates. Will be shifted after cropping. + filepath (str or None): + An optional file path for debug logging (only used if strict mode triggers warnings). + strict (bool): + If True, will zero-pad to ensure the exact target_shape even if the cropped region is smaller. + + Raises: + AssertionError: + If the input image is smaller than `target_shape`. + ValueError: + If the cropped image is larger than `target_shape` (in strict mode), which should not normally happen. + + Returns: + tuple: + (cropped_image, cropped_depth_map, updated_intrinsic, updated_track) + + - cropped_image (np.ndarray): Cropped (and optionally padded) image. + - cropped_depth_map (np.ndarray or None): Cropped (and optionally padded) depth map. + - updated_intrinsic (np.ndarray): Intrinsic matrix adjusted for the crop. + - updated_track (np.ndarray or None): Track array adjusted for the crop, or None if track was not provided. + """ + original_size = np.array(image.shape) + intrinsic = np.copy(intrinsic) + + if original_size[0] < target_shape[0]: + error_message = ( + f"Width check failed: original width {original_size[0]} " + f"is less than target width {target_shape[0]}." + ) + print(error_message) + raise AssertionError(error_message) + + if original_size[1] < target_shape[1]: + error_message = ( + f"Height check failed: original height {original_size[1]} " + f"is less than target height {target_shape[1]}." 
+ ) + print(error_message) + raise AssertionError(error_message) + + # Identify principal point (cx, cy) from intrinsic + cx = (intrinsic[1, 2]) + cy = (intrinsic[0, 2]) + + # Compute how far we can crop in each direction + if strict: + half_x = min((target_shape[0] / 2), cx) + half_y = min((target_shape[1] / 2), cy) + else: + half_x = min((target_shape[0] / 2), cx, original_size[0] - cx) + half_y = min((target_shape[1] / 2), cy, original_size[1] - cy) + + # Compute starting indices + start_x = math.floor(cx) - math.floor(half_x) + start_y = math.floor(cy) - math.floor(half_y) + + assert start_x >= 0 + assert start_y >= 0 + + # Compute ending indices + if strict: + end_x = start_x + target_shape[0] + end_y = start_y + target_shape[1] + else: + end_x = start_x + 2 * math.floor(half_x) + end_y = start_y + 2 * math.floor(half_y) + + # Perform the crop + image = image[start_x:end_x, start_y:end_y, :] + if depth_map is not None: + depth_map = depth_map[start_x:end_x, start_y:end_y] + + # Shift the principal point in the intrinsic + intrinsic[1, 2] = intrinsic[1, 2] - start_x + intrinsic[0, 2] = intrinsic[0, 2] - start_y + + # Adjust track if provided + if track is not None: + track[:, 1] = track[:, 1] - start_x + track[:, 0] = track[:, 0] - start_y + + # If strict, zero-pad if the new shape is smaller than target_shape + if strict: + if (image.shape[:2] != target_shape).any(): + print(f"{filepath} does not meet the target shape") + current_h, current_w = image.shape[:2] + target_h, target_w = target_shape[0], target_shape[1] + pad_h = target_h - current_h + pad_w = target_w - current_w + if pad_h < 0 or pad_w < 0: + raise ValueError( + f"The cropped image is bigger than the target shape: " + f"cropped=({current_h},{current_w}), " + f"target=({target_h},{target_w})." + ) + image = np.pad( + image, + pad_width=((0, pad_h), (0, pad_w), (0, 0)), + mode="constant", + constant_values=0, + ) + if depth_map is not None: + depth_map = np.pad( + depth_map, + pad_width=((0, pad_h), (0, pad_w)), + mode="constant", + constant_values=0, + ) + + return image, depth_map, intrinsic, track + + +def resize_image_depth_and_intrinsic( + image, + depth_map, + intrinsic, + target_shape, + original_size, + track=None, + pixel_center=True, + safe_bound=4, + rescale_aug=True, +): + """ + Resizes the given image and depth map (if provided) to slightly larger than `target_shape`, + updating the intrinsic matrix (and track array if present). Optionally uses random rescaling + to create some additional margin (based on `rescale_aug`). + + Steps: + 1. Compute a scaling factor so that the resized result is at least `target_shape + safe_bound`. + 2. Apply an optional triangular random factor if `rescale_aug=True`. + 3. Resize the image with LANCZOS if downscaling, BICUBIC if upscaling. + 4. Resize the depth map with nearest-neighbor. + 5. Update the camera intrinsic and track coordinates (if any). + + Args: + image (np.ndarray): + Input image array (H, W, 3). + depth_map (np.ndarray or None): + Depth map array (H, W), or None if unavailable. + intrinsic (np.ndarray): + Camera intrinsic matrix (3x3). + target_shape (np.ndarray or tuple[int, int]): + Desired final shape (height, width). + original_size (np.ndarray or tuple[int, int]): + Original size of the image in (height, width). + track (np.ndarray or None): + Optional (N, 2) array of pixel coordinates. Will be scaled. + pixel_center (bool): + If True, accounts for 0.5 pixel center shift during resizing. 
+ safe_bound (int or float): + Additional margin (in pixels) to add to target_shape before resizing. + rescale_aug (bool): + If True, randomly increase the `safe_bound` within a certain range to simulate augmentation. + + Returns: + tuple: + (resized_image, resized_depth_map, updated_intrinsic, updated_track) + + - resized_image (np.ndarray): The resized image. + - resized_depth_map (np.ndarray or None): The resized depth map. + - updated_intrinsic (np.ndarray): Camera intrinsic updated for new resolution. + - updated_track (np.ndarray or None): Track array updated or None if not provided. + + Raises: + AssertionError: + If the shapes of the resized image and depth map do not match. + """ + if rescale_aug: + random_boundary = np.random.triangular(0, 0, 0.3) + safe_bound = safe_bound + random_boundary * target_shape.max() + + resize_scales = (target_shape + safe_bound) / original_size + max_resize_scale = np.max(resize_scales) + intrinsic = np.copy(intrinsic) + + # Convert image to PIL for resizing + image = Image.fromarray(image) + input_resolution = np.array(image.size) + output_resolution = np.floor(input_resolution * max_resize_scale).astype(int) + image = image.resize(tuple(output_resolution), resample=lanczos if max_resize_scale < 1 else bicubic) + image = np.array(image) + + if depth_map is not None: + depth_map = cv2.resize( + depth_map, + output_resolution, + fx=max_resize_scale, + fy=max_resize_scale, + interpolation=cv2.INTER_NEAREST, + ) + + actual_size = np.array(image.shape[:2]) + actual_resize_scale = np.max(actual_size / original_size) + + if pixel_center: + intrinsic[0, 2] = intrinsic[0, 2] + 0.5 + intrinsic[1, 2] = intrinsic[1, 2] + 0.5 + + intrinsic[:2, :] = intrinsic[:2, :] * actual_resize_scale + + if track is not None: + track = track * actual_resize_scale + + if pixel_center: + intrinsic[0, 2] = intrinsic[0, 2] - 0.5 + intrinsic[1, 2] = intrinsic[1, 2] - 0.5 + + assert image.shape[:2] == depth_map.shape[:2] + return image, depth_map, intrinsic, track + + +def threshold_depth_map( + depth_map: np.ndarray, + max_percentile: float = 99, + min_percentile: float = 1, + max_depth: float = -1, +) -> np.ndarray: + """ + Thresholds a depth map using percentile-based limits and optional maximum depth clamping. + + Steps: + 1. If `max_depth > 0`, clamp all values above `max_depth` to zero. + 2. Compute `max_percentile` and `min_percentile` thresholds using nanpercentile. + 3. Zero out values above/below these thresholds, if thresholds are > 0. + + Args: + depth_map (np.ndarray): + Input depth map (H, W). + max_percentile (float): + Upper percentile (0-100). Values above this will be set to zero. + min_percentile (float): + Lower percentile (0-100). Values below this will be set to zero. + max_depth (float): + Absolute maximum depth. If > 0, any depth above this is set to zero. + If <= 0, no maximum-depth clamp is applied. + + Returns: + np.ndarray: + Depth map (H, W) after thresholding. Some or all values may be zero. + Returns None if depth_map is None. 
+ """ + if depth_map is None: + return None + + depth_map = depth_map.astype(float, copy=True) + + # Optional clamp by max_depth + if max_depth > 0: + depth_map[depth_map > max_depth] = 0.0 + + # Percentile-based thresholds + depth_max_thres = ( + np.nanpercentile(depth_map, max_percentile) if max_percentile > 0 else None + ) + depth_min_thres = ( + np.nanpercentile(depth_map, min_percentile) if min_percentile > 0 else None + ) + + # Apply the thresholds if they are > 0 + if depth_max_thres is not None and depth_max_thres > 0: + depth_map[depth_map > depth_max_thres] = 0.0 + if depth_min_thres is not None and depth_min_thres > 0: + depth_map[depth_map < depth_min_thres] = 0.0 + + return depth_map + + +def depth_to_world_coords_points( + depth_map: np.ndarray, + extrinsic: np.ndarray, + intrinsic: np.ndarray, + eps=1e-8, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Converts a depth map to world coordinates (HxWx3) given the camera extrinsic and intrinsic. + Returns both the world coordinates and the intermediate camera coordinates, + as well as a mask for valid depth. + + Args: + depth_map (np.ndarray): + Depth map of shape (H, W). + extrinsic (np.ndarray): + Extrinsic matrix of shape (3, 4), representing the camera pose in OpenCV convention (camera-from-world). + intrinsic (np.ndarray): + Intrinsic matrix of shape (3, 3). + eps (float): + Small epsilon for thresholding valid depth. + + Returns: + tuple[np.ndarray, np.ndarray, np.ndarray]: + (world_coords_points, cam_coords_points, point_mask) + + - world_coords_points: (H, W, 3) array of 3D points in world frame. + - cam_coords_points: (H, W, 3) array of 3D points in camera frame. + - point_mask: (H, W) boolean array where True indicates valid (non-zero) depth. + """ + if depth_map is None: + return None, None, None + + # Valid depth mask + point_mask = depth_map > eps + + # Convert depth map to camera coordinates + cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) + + # The extrinsic is camera-from-world, so invert it to transform camera->world + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] + R_cam_to_world = cam_to_world_extrinsic[:3, :3] + t_cam_to_world = cam_to_world_extrinsic[:3, 3] + + # Apply the rotation and translation to the camera coordinates + world_coords_points = ( + np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world + ) # HxWx3, 3x3 -> HxWx3 + # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world + + return world_coords_points, cam_coords_points, point_mask + + +def depth_to_cam_coords_points( + depth_map: np.ndarray, intrinsic: np.ndarray +) -> np.ndarray: + """ + Unprojects a depth map into camera coordinates, returning (H, W, 3). + + Args: + depth_map (np.ndarray): + Depth map of shape (H, W). + intrinsic (np.ndarray): + 3x3 camera intrinsic matrix. + Assumes zero skew and standard OpenCV layout: + [ fx 0 cx ] + [ 0 fy cy ] + [ 0 0 1 ] + + Returns: + np.ndarray: + An (H, W, 3) array, where each pixel is mapped to (x, y, z) in the camera frame. 
+ """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert ( + intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0 + ), "Intrinsic matrix must have zero skew" + + # Intrinsic parameters + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + # Generate grid of pixel coordinates + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + # Unproject to camera coordinates + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + # Stack to form camera coordinates + return np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + +def rotate_90_degrees( + image, depth_map, extri_opencv, intri_opencv, clockwise=True, track=None +): + """ + Rotates the input image, depth map, and camera parameters by 90 degrees. + + Applies one of two 90-degree rotations: + - Clockwise + - Counterclockwise (if clockwise=False) + + The extrinsic and intrinsic matrices are adjusted accordingly to maintain + correct camera geometry. Track coordinates are also updated if provided. + + Args: + image (np.ndarray): + Input image of shape (H, W, 3). + depth_map (np.ndarray or None): + Depth map of shape (H, W), or None if not available. + extri_opencv (np.ndarray): + Extrinsic matrix (3x4) in OpenCV convention. + intri_opencv (np.ndarray): + Intrinsic matrix (3x3). + clockwise (bool): + If True, rotates the image 90 degrees clockwise; else 90 degrees counterclockwise. + track (np.ndarray or None): + Optional (N, 2) track array. Will be rotated accordingly. + + Returns: + tuple: + ( + rotated_image, + rotated_depth_map, + new_extri_opencv, + new_intri_opencv, + new_track + ) + + Where each is the updated version after the rotation. + """ + image_height, image_width = image.shape[:2] + + # Rotate the image and depth map + rotated_image, rotated_depth_map = rotate_image_and_depth_rot90(image, depth_map, clockwise) + # Adjust the intrinsic matrix + new_intri_opencv = adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise) + + if track is not None: + new_track = adjust_track_rot90(track, image_width, image_height, clockwise) + else: + new_track = None + + # Adjust the extrinsic matrix + new_extri_opencv = adjust_extrinsic_matrix_rot90(extri_opencv, clockwise) + + return ( + rotated_image, + rotated_depth_map, + new_extri_opencv, + new_intri_opencv, + new_track, + ) + + +def rotate_image_and_depth_rot90(image, depth_map, clockwise): + """ + Rotates the given image and depth map by 90 degrees (clockwise or counterclockwise), + using a transpose+flip pattern. + + Args: + image (np.ndarray): + Input image of shape (H, W, 3). + depth_map (np.ndarray or None): + Depth map of shape (H, W), or None if not available. + clockwise (bool): + If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise. 
+ + Returns: + tuple: + (rotated_image, rotated_depth_map) + """ + rotated_depth_map = None + if clockwise: + rotated_image = np.transpose(image, (1, 0, 2)) # Transpose height and width + rotated_image = np.flip(rotated_image, axis=1) # Flip horizontally + if depth_map is not None: + rotated_depth_map = np.transpose(depth_map, (1, 0)) + rotated_depth_map = np.flip(rotated_depth_map, axis=1) + else: + rotated_image = np.transpose(image, (1, 0, 2)) # Transpose height and width + rotated_image = np.flip(rotated_image, axis=0) # Flip vertically + if depth_map is not None: + rotated_depth_map = np.transpose(depth_map, (1, 0)) + rotated_depth_map = np.flip(rotated_depth_map, axis=0) + return np.copy(rotated_image), np.copy(rotated_depth_map) + + +def adjust_extrinsic_matrix_rot90(extri_opencv, clockwise): + """ + Adjusts the extrinsic matrix (3x4) for a 90-degree rotation of the image. + + The rotation is in the image plane. This modifies the camera orientation + accordingly. The function applies either a clockwise or counterclockwise + 90-degree rotation. + + Args: + extri_opencv (np.ndarray): + Extrinsic matrix (3x4) in OpenCV convention. + clockwise (bool): + If True, rotate extrinsic for a 90-degree clockwise image rotation; + otherwise, counterclockwise. + + Returns: + np.ndarray: + A new 3x4 extrinsic matrix after the rotation. + """ + R = extri_opencv[:, :3] + t = extri_opencv[:, 3] + + if clockwise: + R_rotation = np.array([ + [0, -1, 0], + [1, 0, 0], + [0, 0, 1] + ]) + else: + R_rotation = np.array([ + [0, 1, 0], + [-1, 0, 0], + [0, 0, 1] + ]) + + new_R = np.dot(R_rotation, R) + new_t = np.dot(R_rotation, t) + new_extri_opencv = np.hstack((new_R, new_t.reshape(-1, 1))) + return new_extri_opencv + + +def adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise): + """ + Adjusts the intrinsic matrix (3x3) for a 90-degree rotation of the image in the image plane. + + Args: + intri_opencv (np.ndarray): + Intrinsic matrix (3x3). + image_width (int): + Original width of the image. + image_height (int): + Original height of the image. + clockwise (bool): + If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise. + + Returns: + np.ndarray: + A new 3x3 intrinsic matrix after the rotation. + """ + fx, fy, cx, cy = ( + intri_opencv[0, 0], + intri_opencv[1, 1], + intri_opencv[0, 2], + intri_opencv[1, 2], + ) + + new_intri_opencv = np.eye(3) + if clockwise: + new_intri_opencv[0, 0] = fy + new_intri_opencv[1, 1] = fx + new_intri_opencv[0, 2] = image_height - cy + new_intri_opencv[1, 2] = cx + else: + new_intri_opencv[0, 0] = fy + new_intri_opencv[1, 1] = fx + new_intri_opencv[0, 2] = cy + new_intri_opencv[1, 2] = image_width - cx + + return new_intri_opencv + + +def adjust_track_rot90(track, image_width, image_height, clockwise): + """ + Adjusts a track (N, 2) for a 90-degree rotation of the image in the image plane. + + Args: + track (np.ndarray): + (N, 2) array of pixel coordinates, each row is (x, y). + image_width (int): + Original image width. + image_height (int): + Original image height. + clockwise (bool): + Whether the rotation is 90 degrees clockwise or counterclockwise. + + Returns: + np.ndarray: + A new track of shape (N, 2) after rotation. 
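+
+    Example (illustrative; a single point in a 100x50 image):
+        >>> import numpy as np
+        >>> track = np.array([[10.0, 20.0]])  # one (x, y) point
+        >>> adjust_track_rot90(track, image_width=100, image_height=50, clockwise=True).tolist()
+        [[20.0, 89.0]]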
+ """ + if clockwise: + # (x, y) -> (y, image_width - 1 - x) + new_track = np.stack((track[:, 1], image_width - 1 - track[:, 0]), axis=-1) + else: + # (x, y) -> (image_height - 1 - y, x) + new_track = np.stack((image_height - 1 - track[:, 1], track[:, 0]), axis=-1) + + return new_track + + +def read_image_cv2(path: str, rgb: bool = True) -> np.ndarray: + """ + Reads an image from disk using OpenCV, returning it as an RGB image array (H, W, 3). + + Args: + path (str): + File path to the image. + rgb (bool): + If True, convert the image to RGB. + If False, leave the image in BGR/grayscale. + + Returns: + np.ndarray or None: + A numpy array of shape (H, W, 3) if successful, + or None if the file does not exist or could not be read. + """ + if not os.path.exists(path) or os.path.getsize(path) == 0: + print(f"File does not exist or is empty: {path}") + return None + + img = cv2.imread(path) + if img is None: + print(f"Could not load image={path}. Retrying...") + img = cv2.imread(path) + if img is None: + print("Retry failed.") + return None + + if rgb: + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + return img + + +def read_depth(path: str, scale_adjustment=1.0) -> np.ndarray: + """ + Reads a depth map from disk in either .exr or .png format. The .exr is loaded using OpenCV + with the environment variable OPENCV_IO_ENABLE_OPENEXR=1. The .png is assumed to be a 16-bit + PNG (converted from half float). + + Args: + path (str): + File path to the depth image. Must end with .exr or .png. + scale_adjustment (float): + A multiplier for adjusting the loaded depth values (default=1.0). + + Returns: + np.ndarray: + A float32 array (H, W) containing the loaded depth. Zeros or non-finite values + may indicate invalid regions. + + Raises: + ValueError: + If the file extension is not supported. + """ + if path.lower().endswith(".exr"): + # Ensure OPENCV_IO_ENABLE_OPENEXR is set to "1" + d = cv2.imread(path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[..., 0] + d[d > 1e9] = 0.0 + elif path.lower().endswith(".png"): + d = load_16big_png_depth(path) + else: + raise ValueError(f'unsupported depth file name "{path}"') + + d = d * scale_adjustment + d[~np.isfinite(d)] = 0.0 + + return d + + +def load_16big_png_depth(depth_png: str) -> np.ndarray: + """ + Loads a 16-bit PNG as a half-float depth map (H, W), returning a float32 NumPy array. + + Implementation detail: + - PIL loads 16-bit data as 32-bit "I" mode. + - We reinterpret the bits as float16, then cast to float32. + + Args: + depth_png (str): + File path to the 16-bit PNG. + + Returns: + np.ndarray: + A float32 depth array of shape (H, W). + """ + with Image.open(depth_png) as depth_pil: + depth = ( + np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0])) + ) + return depth diff --git a/vggt/training/data/datasets/co3d.py b/vggt/training/data/datasets/co3d.py new file mode 100644 index 0000000000000000000000000000000000000000..5636626d10fe34b012535f9b319ec6ce83917a6c --- /dev/null +++ b/vggt/training/data/datasets/co3d.py @@ -0,0 +1,280 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
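+
+# Data loader for the CO3D (Common Objects in 3D) dataset. Sequences are read from
+# per-category annotation files named "{category}_{split}.jgz" under CO3D_ANNOTATION_DIR,
+# while images, geometric depth maps ("/depths") and depth masks ("/depth_masks") are
+# loaded relative to CO3D_DIR. Only the "seen" categories listed below are used.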
+ +import gzip +import json +import os.path as osp +import os +import logging + +import cv2 +import random +import numpy as np + + +from data.dataset_util import * +from data.base_dataset import BaseDataset + + +SEEN_CATEGORIES = [ + "apple", + "backpack", + "banana", + "baseballbat", + "baseballglove", + "bench", + "bicycle", + "bottle", + "bowl", + "broccoli", + "cake", + "car", + "carrot", + "cellphone", + "chair", + "cup", + "donut", + "hairdryer", + "handbag", + "hydrant", + "keyboard", + "laptop", + "microwave", + "motorcycle", + "mouse", + "orange", + "parkingmeter", + "pizza", + "plant", + "stopsign", + "teddybear", + "toaster", + "toilet", + "toybus", + "toyplane", + "toytrain", + "toytruck", + "tv", + "umbrella", + "vase", + "wineglass", +] + + +class Co3dDataset(BaseDataset): + def __init__( + self, + common_conf, + split: str = "train", + CO3D_DIR: str = None, + CO3D_ANNOTATION_DIR: str = None, + min_num_images: int = 24, + len_train: int = 100000, + len_test: int = 10000, + ): + """ + Initialize the Co3dDataset. + + Args: + common_conf: Configuration object with common settings. + split (str): Dataset split, either 'train' or 'test'. + CO3D_DIR (str): Directory path to CO3D data. + CO3D_ANNOTATION_DIR (str): Directory path to CO3D annotations. + min_num_images (int): Minimum number of images per sequence. + len_train (int): Length of the training dataset. + len_test (int): Length of the test dataset. + Raises: + ValueError: If CO3D_DIR or CO3D_ANNOTATION_DIR is not specified. + """ + super().__init__(common_conf=common_conf) + + self.debug = common_conf.debug + self.training = common_conf.training + self.get_nearby = common_conf.get_nearby + self.load_depth = common_conf.load_depth + self.inside_random = common_conf.inside_random + self.allow_duplicate_img = common_conf.allow_duplicate_img + + if CO3D_DIR is None or CO3D_ANNOTATION_DIR is None: + raise ValueError("Both CO3D_DIR and CO3D_ANNOTATION_DIR must be specified.") + + category = sorted(SEEN_CATEGORIES) + + if self.debug: + category = ["apple"] + + if split == "train": + split_name_list = ["train"] + self.len_train = len_train + elif split == "test": + split_name_list = ["test"] + self.len_train = len_test + else: + raise ValueError(f"Invalid split: {split}") + + self.invalid_sequence = [] # set any invalid sequence names here + + + self.category_map = {} + self.data_store = {} + self.seqlen = None + self.min_num_images = min_num_images + + logging.info(f"CO3D_DIR is {CO3D_DIR}") + + self.CO3D_DIR = CO3D_DIR + self.CO3D_ANNOTATION_DIR = CO3D_ANNOTATION_DIR + + total_frame_num = 0 + + for c in category: + for split_name in split_name_list: + annotation_file = osp.join( + self.CO3D_ANNOTATION_DIR, f"{c}_{split_name}.jgz" + ) + + try: + with gzip.open(annotation_file, "r") as fin: + annotation = json.loads(fin.read()) + except FileNotFoundError: + logging.error(f"Annotation file not found: {annotation_file}") + continue + + for seq_name, seq_data in annotation.items(): + if len(seq_data) < min_num_images: + continue + if seq_name in self.invalid_sequence: + continue + total_frame_num += len(seq_data) + self.data_store[seq_name] = seq_data + + self.sequence_list = list(self.data_store.keys()) + self.sequence_list_len = len(self.sequence_list) + self.total_frame_num = total_frame_num + + status = "Training" if self.training else "Testing" + logging.info(f"{status}: Co3D Data size: {self.sequence_list_len}") + logging.info(f"{status}: Co3D Data dataset length: {len(self)}") + + def get_data( + self, + seq_index: int = None, + 
img_per_seq: int = None, + seq_name: str = None, + ids: list = None, + aspect_ratio: float = 1.0, + ) -> dict: + """ + Retrieve data for a specific sequence. + + Args: + seq_index (int): Index of the sequence to retrieve. + img_per_seq (int): Number of images per sequence. + seq_name (str): Name of the sequence. + ids (list): Specific IDs to retrieve. + aspect_ratio (float): Aspect ratio for image processing. + + Returns: + dict: A batch of data including images, depths, and other metadata. + """ + if self.inside_random: + seq_index = random.randint(0, self.sequence_list_len - 1) + + if seq_name is None: + seq_name = self.sequence_list[seq_index] + + metadata = self.data_store[seq_name] + + if ids is None: + ids = np.random.choice( + len(metadata), img_per_seq, replace=self.allow_duplicate_img + ) + + annos = [metadata[i] for i in ids] + + target_image_shape = self.get_target_shape(aspect_ratio) + + images = [] + depths = [] + cam_points = [] + world_points = [] + point_masks = [] + extrinsics = [] + intrinsics = [] + image_paths = [] + original_sizes = [] + + for anno in annos: + filepath = anno["filepath"] + + image_path = osp.join(self.CO3D_DIR, filepath) + image = read_image_cv2(image_path) + + if self.load_depth: + depth_path = image_path.replace("/images", "/depths") + ".geometric.png" + depth_map = read_depth(depth_path, 1.0) + + mvs_mask_path = image_path.replace( + "/images", "/depth_masks" + ).replace(".jpg", ".png") + mvs_mask = cv2.imread(mvs_mask_path, cv2.IMREAD_GRAYSCALE) > 128 + depth_map[~mvs_mask] = 0 + + depth_map = threshold_depth_map( + depth_map, min_percentile=-1, max_percentile=98 + ) + else: + depth_map = None + + original_size = np.array(image.shape[:2]) + extri_opencv = np.array(anno["extri"]) + intri_opencv = np.array(anno["intri"]) + + ( + image, + depth_map, + extri_opencv, + intri_opencv, + world_coords_points, + cam_coords_points, + point_mask, + _, + ) = self.process_one_image( + image, + depth_map, + extri_opencv, + intri_opencv, + original_size, + target_image_shape, + filepath=filepath, + ) + + images.append(image) + depths.append(depth_map) + extrinsics.append(extri_opencv) + intrinsics.append(intri_opencv) + cam_points.append(cam_coords_points) + world_points.append(world_coords_points) + point_masks.append(point_mask) + image_paths.append(image_path) + original_sizes.append(original_size) + + set_name = "co3d" + + batch = { + "seq_name": set_name + "_" + seq_name, + "ids": ids, + "frame_num": len(extrinsics), + "images": images, + "depths": depths, + "extrinsics": extrinsics, + "intrinsics": intrinsics, + "cam_points": cam_points, + "world_points": world_points, + "point_masks": point_masks, + "original_sizes": original_sizes, + } + return batch diff --git a/vggt/training/data/datasets/vkitti.py b/vggt/training/data/datasets/vkitti.py new file mode 100644 index 0000000000000000000000000000000000000000..7acc9f56179581d2ed8740dd9beaba1b44a68caf --- /dev/null +++ b/vggt/training/data/datasets/vkitti.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
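+
+# Data loader for the Virtual KITTI 2 dataset. Each sequence directory holds
+# rgb_XXXXX.jpg frames and depth_XXXXX.png maps (depth stored in centimeters; values
+# beyond 80 m are discarded below), plus per-scene extrinsic.txt / intrinsic.txt files
+# with one row per (frame, camera). See data/preprocess/vkitti.sh for the expected
+# download layout.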
+ +import os +import os.path as osp +import logging +import random +import glob + +import cv2 +import numpy as np + +from data.dataset_util import * +from data.base_dataset import BaseDataset + + +class VKittiDataset(BaseDataset): + def __init__( + self, + common_conf, + split: str = "train", + VKitti_DIR: str = "/checkpoint/repligen/jianyuan/datasets/vkitti/", + min_num_images: int = 24, + len_train: int = 100000, + len_test: int = 10000, + expand_ratio: int = 8, + ): + """ + Initialize the VKittiDataset. + + Args: + common_conf: Configuration object with common settings. + split (str): Dataset split, either 'train' or 'test'. + VKitti_DIR (str): Directory path to VKitti data. + min_num_images (int): Minimum number of images per sequence. + len_train (int): Length of the training dataset. + len_test (int): Length of the test dataset. + expand_range (int): Range for expanding nearby image selection. + get_nearby_thres (int): Threshold for nearby image selection. + """ + super().__init__(common_conf=common_conf) + + self.debug = common_conf.debug + self.training = common_conf.training + self.get_nearby = common_conf.get_nearby + self.inside_random = common_conf.inside_random + self.allow_duplicate_img = common_conf.allow_duplicate_img + + self.expand_ratio = expand_ratio + self.VKitti_DIR = VKitti_DIR + self.min_num_images = min_num_images + + if split == "train": + self.len_train = len_train + elif split == "test": + self.len_train = len_test + else: + raise ValueError(f"Invalid split: {split}") + + logging.info(f"VKitti_DIR is {self.VKitti_DIR}") + + # Load or generate sequence list + txt_path = osp.join(self.VKitti_DIR, "sequence_list.txt") + if osp.exists(txt_path): + with open(txt_path, 'r') as f: + sequence_list = [line.strip() for line in f.readlines()] + else: + # Generate sequence list and save to txt + sequence_list = glob.glob(osp.join(self.VKitti_DIR, "*/*/*/rgb/*")) + sequence_list = [file_path.split(self.VKitti_DIR)[-1].lstrip('/') for file_path in sequence_list] + sequence_list = sorted(sequence_list) + + # Save to txt file + with open(txt_path, 'w') as f: + f.write('\n'.join(sequence_list)) + + self.sequence_list = sequence_list + self.sequence_list_len = len(self.sequence_list) + + self.depth_max = 80 + + status = "Training" if self.training else "Testing" + logging.info(f"{status}: VKitti Real Data size: {self.sequence_list_len}") + logging.info(f"{status}: VKitti Data dataset length: {len(self)}") + + def get_data( + self, + seq_index: int = None, + img_per_seq: int = None, + seq_name: str = None, + ids: list = None, + aspect_ratio: float = 1.0, + ) -> dict: + """ + Retrieve data for a specific sequence. + + Args: + seq_index (int): Index of the sequence to retrieve. + img_per_seq (int): Number of images per sequence. + seq_name (str): Name of the sequence. + ids (list): Specific IDs to retrieve. + aspect_ratio (float): Aspect ratio for image processing. + + Returns: + dict: A batch of data including images, depths, and other metadata. 
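+
+        Example (illustrative; assumes `dataset` is an already constructed VKittiDataset):
+            batch = dataset.get_data(seq_index=0, img_per_seq=4, aspect_ratio=1.0)
+            # The returned dict carries lists of per-frame data: "images", "depths",
+            # "extrinsics", "intrinsics", "cam_points", "world_points", "point_masks"
+            # and "original_sizes", plus "seq_name", "ids" and "frame_num".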
+ """ + if self.inside_random and self.training: + seq_index = random.randint(0, self.sequence_list_len - 1) + + if seq_name is None: + seq_name = self.sequence_list[seq_index] + + camera_id = int(seq_name[-1]) + + # Load camera parameters + try: + camera_parameters = np.loadtxt( + osp.join(self.VKitti_DIR, "/".join(seq_name.split("/")[:2]), "extrinsic.txt"), + delimiter=" ", + skiprows=1 + ) + camera_parameters = camera_parameters[camera_parameters[:, 1] == camera_id] + + camera_intrinsic = np.loadtxt( + osp.join(self.VKitti_DIR, "/".join(seq_name.split("/")[:2]), "intrinsic.txt"), + delimiter=" ", + skiprows=1 + ) + camera_intrinsic = camera_intrinsic[camera_intrinsic[:, 1] == camera_id] + except Exception as e: + logging.error(f"Error loading camera parameters for {seq_name}: {e}") + raise + + num_images = len(camera_parameters) + + if ids is None: + ids = np.random.choice(num_images, img_per_seq, replace=self.allow_duplicate_img) + + if self.get_nearby: + ids = self.get_nearby_ids(ids, num_images, expand_ratio=self.expand_ratio) + + target_image_shape = self.get_target_shape(aspect_ratio) + + images = [] + depths = [] + cam_points = [] + world_points = [] + point_masks = [] + extrinsics = [] + intrinsics = [] + original_sizes = [] + + for image_idx in ids: + image_filepath = osp.join(self.VKitti_DIR, seq_name, f"rgb_{image_idx:05d}.jpg") + depth_filepath = osp.join(self.VKitti_DIR, seq_name, f"depth_{image_idx:05d}.png").replace("/rgb", "/depth") + + image = read_image_cv2(image_filepath) + depth_map = cv2.imread(depth_filepath, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + depth_map = depth_map / 100 + depth_map = threshold_depth_map(depth_map, max_percentile=-1, min_percentile=-1, max_depth=self.depth_max) + + assert image.shape[:2] == depth_map.shape, f"Image and depth shape mismatch: {image.shape[:2]} vs {depth_map.shape}" + + original_size = np.array(image.shape[:2]) + + # Process camera matrices + extri_opencv = camera_parameters[image_idx][2:].reshape(4, 4) + extri_opencv = extri_opencv[:3] + + intri_opencv = np.eye(3) + intri_opencv[0, 0] = camera_intrinsic[image_idx][-4] + intri_opencv[1, 1] = camera_intrinsic[image_idx][-3] + intri_opencv[0, 2] = camera_intrinsic[image_idx][-2] + intri_opencv[1, 2] = camera_intrinsic[image_idx][-1] + + ( + image, + depth_map, + extri_opencv, + intri_opencv, + world_coords_points, + cam_coords_points, + point_mask, + _, + ) = self.process_one_image( + image, + depth_map, + extri_opencv, + intri_opencv, + original_size, + target_image_shape, + filepath=image_filepath, + ) + + if (image.shape[:2] != target_image_shape).any(): + logging.error(f"Wrong shape for {seq_name}: expected {target_image_shape}, got {image.shape[:2]}") + continue + + images.append(image) + depths.append(depth_map) + extrinsics.append(extri_opencv) + intrinsics.append(intri_opencv) + cam_points.append(cam_coords_points) + world_points.append(world_coords_points) + point_masks.append(point_mask) + original_sizes.append(original_size) + + set_name = "vkitti" + batch = { + "seq_name": set_name + "_" + seq_name, + "ids": ids, + "frame_num": len(extrinsics), + "images": images, + "depths": depths, + "extrinsics": extrinsics, + "intrinsics": intrinsics, + "cam_points": cam_points, + "world_points": world_points, + "point_masks": point_masks, + "original_sizes": original_sizes, + } + return batch \ No newline at end of file diff --git a/vggt/training/data/dynamic_dataloader.py b/vggt/training/data/dynamic_dataloader.py new file mode 100644 index 
0000000000000000000000000000000000000000..11d1224b80283e7c8149807e9d64bd97e37c8d28 --- /dev/null +++ b/vggt/training/data/dynamic_dataloader.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from hydra.utils import instantiate +import random +import numpy as np +from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset, Sampler +from abc import ABC, abstractmethod + +from .worker_fn import get_worker_init_fn + +class DynamicTorchDataset(ABC): + def __init__( + self, + dataset: dict, + common_config: dict, + num_workers: int, + shuffle: bool, + pin_memory: bool, + drop_last: bool = True, + collate_fn: Optional[Callable] = None, + worker_init_fn: Optional[Callable] = None, + persistent_workers: bool = False, + seed: int = 42, + max_img_per_gpu: int = 48, + ) -> None: + self.dataset_config = dataset + self.common_config = common_config + self.num_workers = num_workers + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + self.collate_fn = collate_fn + self.worker_init_fn = worker_init_fn + self.persistent_workers = persistent_workers + self.seed = seed + self.max_img_per_gpu = max_img_per_gpu + + # Instantiate the dataset + self.dataset = instantiate(dataset, common_config=common_config, _recursive_=False) + + # Extract aspect ratio and image number ranges from the configuration + self.aspect_ratio_range = common_config.augs.aspects # e.g., [0.5, 1.0] + self.image_num_range = common_config.img_nums # e.g., [2, 24] + + # Validate the aspect ratio and image number ranges + if len(self.aspect_ratio_range) != 2 or self.aspect_ratio_range[0] > self.aspect_ratio_range[1]: + raise ValueError(f"aspect_ratio_range must be [min, max] with min <= max, got {self.aspect_ratio_range}") + if len(self.image_num_range) != 2 or self.image_num_range[0] < 1 or self.image_num_range[0] > self.image_num_range[1]: + raise ValueError(f"image_num_range must be [min, max] with 1 <= min <= max, got {self.image_num_range}") + + # Create samplers + self.sampler = DynamicDistributedSampler(self.dataset, seed=seed, shuffle=shuffle) + self.batch_sampler = DynamicBatchSampler( + self.sampler, + self.aspect_ratio_range, + self.image_num_range, + seed=seed, + max_img_per_gpu=max_img_per_gpu + ) + + def get_loader(self, epoch): + print("Building dynamic dataloader with epoch:", epoch) + + # Set the epoch for the sampler + self.sampler.set_epoch(epoch) + if hasattr(self.dataset, "epoch"): + self.dataset.epoch = epoch + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + # Create and return the dataloader + return DataLoader( + self.dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + batch_sampler=self.batch_sampler, + collate_fn=self.collate_fn, + persistent_workers=self.persistent_workers, + worker_init_fn=get_worker_init_fn( + seed=self.seed, + num_workers=self.num_workers, + epoch=epoch, + worker_init_fn=self.worker_init_fn, + ), + ) + + +class DynamicBatchSampler(Sampler): + """ + A custom batch sampler that dynamically adjusts batch size, aspect ratio, and image number + for each sample. Batches within a sample share the same aspect ratio and image number. 
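+
+    Example (illustrative numbers): with max_img_per_gpu=48, drawing image_num=12 gives
+    batch_size = floor(48 / 12) = 4, i.e. the next batch holds 4 samples of 12 images that
+    all share one randomly drawn aspect ratio; drawing image_num=24 gives batch_size = 2.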
+ """ + def __init__(self, + sampler, + aspect_ratio_range, + image_num_range, + epoch=0, + seed=42, + max_img_per_gpu=48): + """ + Initializes the dynamic batch sampler. + + Args: + sampler: Instance of DynamicDistributedSampler. + aspect_ratio_range: List containing [min_aspect_ratio, max_aspect_ratio]. + image_num_range: List containing [min_images, max_images] per sample. + epoch: Current epoch number. + seed: Random seed for reproducibility. + max_img_per_gpu: Maximum number of images to fit in GPU memory. + """ + self.sampler = sampler + self.aspect_ratio_range = aspect_ratio_range + self.image_num_range = image_num_range + self.rng = random.Random() + + # Uniformly sample from the range of possible image numbers + # For any image number, the weight is 1.0 (uniform sampling). You can set any different weights here. + self.image_num_weights = {num_images: 1.0 for num_images in range(image_num_range[0], image_num_range[1]+1)} + + # Possible image numbers, e.g., [2, 3, 4, ..., 24] + self.possible_nums = np.array([n for n in self.image_num_weights.keys() + if self.image_num_range[0] <= n <= self.image_num_range[1]]) + + # Normalize weights for sampling + weights = [self.image_num_weights[n] for n in self.possible_nums] + self.normalized_weights = np.array(weights) / sum(weights) + + # Maximum image number per GPU + self.max_img_per_gpu = max_img_per_gpu + + # Set the epoch for the sampler + self.set_epoch(epoch + seed) + + def set_epoch(self, epoch): + """ + Sets the epoch for this sampler, affecting the random sequence. + + Args: + epoch: The epoch number. + """ + self.sampler.set_epoch(epoch) + self.epoch = epoch + self.rng.seed(epoch * 100) + + def __iter__(self): + """ + Yields batches of samples with synchronized dynamic parameters. + + Returns: + Iterator yielding batches of indices with associated parameters. + """ + sampler_iterator = iter(self.sampler) + + while True: + try: + # Sample random image number and aspect ratio + random_image_num = int(np.random.choice(self.possible_nums, p=self.normalized_weights)) + random_aspect_ratio = round(self.rng.uniform(self.aspect_ratio_range[0], self.aspect_ratio_range[1]), 2) + + # Update sampler parameters + self.sampler.update_parameters( + aspect_ratio=random_aspect_ratio, + image_num=random_image_num + ) + + # Calculate batch size based on max images per GPU and current image number + batch_size = self.max_img_per_gpu / random_image_num + batch_size = np.floor(batch_size).astype(int) + batch_size = max(1, batch_size) # Ensure batch size is at least 1 + + # Collect samples for the current batch + current_batch = [] + for _ in range(batch_size): + try: + item = next(sampler_iterator) # item is (idx, aspect_ratio, image_num) + current_batch.append(item) + except StopIteration: + break # No more samples + + if not current_batch: + break # No more data to yield + + yield current_batch + + except StopIteration: + break # End of sampler's iterator + + def __len__(self): + # Return a large dummy length + return 1000000 + + +class DynamicDistributedSampler(DistributedSampler): + """ + Extends PyTorch's DistributedSampler to include dynamic aspect_ratio and image_num + parameters, which can be passed into the dataset's __getitem__ method. 
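+
+    Example (a minimal single-process sketch; `my_dataset` is a placeholder):
+        sampler = DynamicDistributedSampler(my_dataset, num_replicas=1, rank=0)
+        sampler.update_parameters(aspect_ratio=0.75, image_num=8)
+        next(iter(sampler))  # -> (some_index, 8, 0.75)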
+ """ + def __init__( + self, + dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = False, + seed: int = 0, + drop_last: bool = False, + ): + super().__init__( + dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=seed, + drop_last=drop_last + ) + self.aspect_ratio = None + self.image_num = None + + def __iter__(self): + """ + Yields a sequence of (index, image_num, aspect_ratio). + Relies on the parent class's logic for shuffling/distributing + the indices across replicas, then attaches extra parameters. + """ + indices_iter = super().__iter__() + + for idx in indices_iter: + yield (idx, self.image_num, self.aspect_ratio,) + + def update_parameters(self, aspect_ratio, image_num): + """ + Updates dynamic parameters for each new epoch or iteration. + + Args: + aspect_ratio: The aspect ratio to set. + image_num: The number of images to set. + """ + self.aspect_ratio = aspect_ratio + self.image_num = image_num diff --git a/vggt/training/data/preprocess/vkitti.sh b/vggt/training/data/preprocess/vkitti.sh new file mode 100644 index 0000000000000000000000000000000000000000..40063f95d32727948d22d1a8a196b15d8e71d3f6 --- /dev/null +++ b/vggt/training/data/preprocess/vkitti.sh @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +mkdir vkitti +cd vkitti + +wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_rgb.tar +tar -xvf vkitti_2.0.3_rgb.tar + +wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_depth.tar +tar -xvf vkitti_2.0.3_depth.tar + +wget https://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_textgt.tar.gz +tar -xvf vkitti_2.0.3_textgt.tar.gz + + +cd .. + diff --git a/vggt/training/data/track_util.py b/vggt/training/data/track_util.py new file mode 100644 index 0000000000000000000000000000000000000000..30bcf1041f266aef7b748ef9c22e3d525b20f54d --- /dev/null +++ b/vggt/training/data/track_util.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import cv2 +import numpy as np +import torch + +import logging + +from vggt.utils.geometry import * + + + +def build_tracks_by_depth(extrinsics, intrinsics, world_points, depths, point_masks, images, pos_rel_thres=0.05, neg_epipolar_thres=16, + boundary_thres=4, target_track_num=512, neg_ratio = 0.0, neg_sample_size_ratio = 0.5, seq_name=None): + """ + Args: + extrinsics: (N, 3, 4) + intrinsics: (N, 3, 3) + world_points: (N, H, W, 3) + depths: (N, H, W) + point_masks: (N, H, W) + pos_rel_thres: float, relative threshold for positive track depth check + neg_epipolar_thres: float, threshold for negative track epipolar check, in px + boundary_thres: int, boundary in px to skip near edges + target_track_num: int, total # tracks to build + neg_ratio: fraction of final tracks that should be negative + neg_sample_size_ratio: fraction of W/H used for random offset + + Returns: + final_tracks: (N, P, 2) float + final_vis_masks: (N, P) bool + final_pos_masks: (P) bool, indicate if a mask is positive or negative + """ + # Wait, should we do this before resizing the image? 
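+    # Overview of the procedure implemented below:
+    #   1) Take the valid 3D points of frame 0 (the query frame) and project them into
+    #      every frame using the given extrinsics/intrinsics.
+    #   2) Mark a projection as visible when it lands inside the image (minus the
+    #      boundary margin) and its projected depth agrees, within `pos_rel_thres`,
+    #      with the depth map at the four pixels surrounding the floored location
+    #      (positive tracks).
+    #   3) Sample random pixel locations, perturb them per frame (the query frame keeps
+    #      the original location), and keep only those whose Sampson epipolar distance
+    #      exceeds `neg_epipolar_thres` in all frames (negative tracks).
+    #   4) Pack positives first and negatives after them into fixed-size tensors of
+    #      `target_track_num` tracks, together with visibility and positive masks.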
+ + B, H, W, _ = world_points.shape + + # We use the first frame as the query frame, so [0] + query_world_points = world_points[0] + query_point_masks = point_masks[0] + + + if (query_point_masks).sum() > 0: + # at least one point + valid_query_points = query_world_points[query_point_masks] + + # image_points: BxPx2 + # cam_points: Bx3xP (yes 3xP instead of Px3). Probably we can change it in the future + image_points, cam_points = project_world_points_to_cam(valid_query_points, extrinsics, intrinsics) + + # proj_depths: BxP + proj_depths = cam_points[:, -1] + + # floor to get the left top corner + uv_int = image_points.floor().long().clone() + + uv_inside_flag = (uv_int[..., 0] >= boundary_thres) & (uv_int[..., 0] < (W - boundary_thres)) & (uv_int[..., 1] >= boundary_thres) & (uv_int[..., 1] < (H - boundary_thres)) + uv_int[~uv_inside_flag] = 0 + batch_indices = torch.arange(B).view(B, 1).expand(-1, uv_int.shape[1]) + + # Use these indices to sample from the depth map + # since we interpolate depths by nearest, + # so assume the left top corner is (x, y) + # we want to check for (x,y), (x+1,y), (x,y+1), (x+1,y+1) + + depth_inside_flag = None + for shift in [(0,0), (1,0), (0,1), (1,1)]: + cur_uv_int = uv_int + torch.tensor(shift) + cur_depth_inside_flag = get_depth_inside_flag(depths, batch_indices, cur_uv_int, proj_depths, pos_rel_thres) + if depth_inside_flag is None: + depth_inside_flag = cur_depth_inside_flag + else: + depth_inside_flag = torch.logical_or(depth_inside_flag, cur_depth_inside_flag) + + # B, P, 2 + positive_tracks = image_points + positive_vis_masks = torch.logical_and(uv_inside_flag, depth_inside_flag) + else: + print(f"No valid query points in {seq_name}") + positive_tracks = torch.zeros(B, target_track_num, 2, device=world_points.device, dtype=torch.float32) + positive_vis_masks = torch.zeros(B, target_track_num, device=world_points.device, dtype=torch.bool) + + + sampled_neg_track_num = target_track_num * 4 # we sample more negative tracks to ensure the quality + + perb_range = [int(W*neg_sample_size_ratio), int(H*neg_sample_size_ratio)] + + # sample negative query points + us = torch.randint(low=0, high=W, size=(1, sampled_neg_track_num), device=world_points.device) + vs = torch.randint(low=0, high=H, size=(1, sampled_neg_track_num), device=world_points.device) + neg_query_uvs = torch.stack([us, vs], dim=-1) + + # construct negative tracks + delta_us = torch.rand(size=(B, sampled_neg_track_num), device=world_points.device) * perb_range[0] + delta_vs = torch.rand(size=(B, sampled_neg_track_num), device=world_points.device) * perb_range[1] + delta_us[0] = 0 + delta_vs[0] = 0 + negative_tracks = neg_query_uvs + torch.stack([delta_us, delta_vs], dim=-1) + + # Do epipolar check here + negative_sampson_distances = track_epipolar_check(negative_tracks, extrinsics, intrinsics) + negative_epipolar_check = (negative_sampson_distances > neg_epipolar_thres).all(dim=0) # we set the threshold to 5 px + # Filter out those satifsfying epipolar check + negative_tracks = negative_tracks[:, negative_epipolar_check] + + # Prepare for output + final_tracks = torch.zeros(B, target_track_num, 2, device=world_points.device, dtype=torch.float32) + final_vis_masks = torch.zeros(B, target_track_num, device=world_points.device, dtype=torch.bool) + final_pos_masks = torch.zeros(target_track_num, device=world_points.device, dtype=torch.bool) + + target_pos_track_num = target_track_num - int(target_track_num * neg_ratio) + sampled_pos_track_num = 0 + + sampled_positive_tracks, 
sampled_positive_vis_masks = sample_positive_tracks(positive_tracks, positive_vis_masks, target_pos_track_num) + sampled_pos_track_num = sampled_positive_tracks.shape[1] + final_tracks[:, :sampled_pos_track_num] = sampled_positive_tracks + final_vis_masks[:, :sampled_pos_track_num] = sampled_positive_vis_masks + final_pos_masks[:sampled_pos_track_num] = True + + + target_neg_track_num = target_track_num - sampled_pos_track_num + + # Now we need to sample negative tracks + # just do simple random sampling + rand_indices = torch.randperm(negative_tracks.shape[1], device=negative_tracks.device) + sampled_neg_tracks = negative_tracks[:, rand_indices[:target_neg_track_num]] + sampled_neg_track_num = sampled_neg_tracks.shape[1] + final_tracks[:, sampled_pos_track_num:sampled_pos_track_num+sampled_neg_track_num] = sampled_neg_tracks + + if sampled_pos_track_num+sampled_neg_track_num!=target_track_num: + logging.warning(f"sampled_pos_track_num+sampled_neg_track_num!=target_track_num: {sampled_pos_track_num+sampled_neg_track_num} != {target_track_num}") + # Do not need to set final_vis_masks and final_pos_masks, because they are all False + # Do not need to check the shape of final_tracks, as it is zeroed out + + + # NOTE: We need to do some visual checks + + + return final_tracks, final_vis_masks, final_pos_masks + + + +def get_depth_inside_flag(depths, batch_indices, uv_int, proj_depths, rel_thres): + sampled_depths = depths[batch_indices, uv_int[..., 1], uv_int[..., 0]] + depth_diff = (proj_depths - sampled_depths).abs() + depth_inside_flag = torch.logical_and(depth_diff < (proj_depths * rel_thres), depth_diff < (sampled_depths * rel_thres)) + return depth_inside_flag + + + + + + + +def sample_positive_tracks(tracks, tracks_mask, track_num, half_top = True, seq_name=None): + # tracks: (B, T, 2) + # tracks_mask: (B, T) + # track_num: int + # half_top: bool + + # if the query frame is not valid, then the track is not valid + tracks_mask[:, tracks_mask[0]==False] = False + + track_frame_num = tracks_mask.sum(dim=0) + tracks_mask[:, track_frame_num<=1] = False + track_frame_num = tracks_mask.sum(dim=0) + + _, track_num_sort_idx = track_frame_num.sort(descending=True) + + if half_top: + if len(track_num_sort_idx)//2 > track_num: + # drop those tracks with too small number of valid frames + # track_num_sort_idx = track_num_sort_idx[:track_num] + track_num_sort_idx = track_num_sort_idx[:len(track_num_sort_idx)//2] + + pick_idx = torch.randperm(len(track_num_sort_idx))[:track_num] + track_num_sort_idx = track_num_sort_idx[pick_idx] + + tracks = tracks[:, track_num_sort_idx].clone() + tracks_mask = tracks_mask[:, track_num_sort_idx].clone() + + + tracks_mask = tracks_mask.bool() # ensure the type is bool + return tracks, tracks_mask + + + + + +# Only for Debugging and Visualization + +def track_epipolar_check(tracks, extrinsics, intrinsics, use_essential_mat = False): + from kornia.geometry.epipolar import sampson_epipolar_distance + + B, T, _ = tracks.shape + essential_mats = get_essential_matrix(extrinsics[0:1].expand(B-1, -1, -1), extrinsics[1:]) + + if use_essential_mat: + tracks_normalized = cam_from_img(tracks, intrinsics) + sampson_distances = sampson_epipolar_distance(tracks_normalized[0:1].expand(B-1, -1, -1), tracks_normalized[1:], essential_mats) + else: + K1 = intrinsics[0:1].expand(B-1, -1, -1) + K2 = intrinsics[1:].expand(B-1, -1, -1) + fundamental_mats = K2.inverse().permute(0, 2, 1).matmul(essential_mats).matmul(K1.inverse()) + sampson_distances = 
sampson_epipolar_distance(tracks[0:1].expand(B-1, -1, -1), tracks[1:], fundamental_mats) + + return sampson_distances + + +def get_essential_matrix(extrinsic1, extrinsic2): + R1 = extrinsic1[:, :3, :3] + t1 = extrinsic1[:, :3, 3] + R2 = extrinsic2[:, :3, :3] + t2 = extrinsic2[:, :3, 3] + + R12 = R2.matmul(R1.permute(0, 2, 1)) + t12 = t2 - R12.matmul(t1[..., None])[..., 0] + E_R = R12 + E_t = -E_R.permute(0, 2, 1).matmul(t12[..., None])[..., 0] + E = E_R.matmul(hat(E_t)) + return E + + + +def hat(v: torch.Tensor) -> torch.Tensor: + N, dim = v.shape + if dim != 3: + raise ValueError("Input vectors have to be 3-dimensional.") + + x, y, z = v.unbind(1) + + h_01 = -z.view(N, 1, 1) + h_02 = y.view(N, 1, 1) + h_10 = z.view(N, 1, 1) + h_12 = -x.view(N, 1, 1) + h_20 = -y.view(N, 1, 1) + h_21 = x.view(N, 1, 1) + + zeros = torch.zeros((N, 1, 1), dtype=v.dtype, device=v.device) + + row1 = torch.cat((zeros, h_01, h_02), dim=2) + row2 = torch.cat((h_10, zeros, h_12), dim=2) + row3 = torch.cat((h_20, h_21, zeros), dim=2) + + h = torch.cat((row1, row2, row3), dim=1) + + return h + + + +def color_from_xy(x, y, W, H, cmap_name="hsv"): + """ + Map (x, y) -> color in (R, G, B). + 1) Normalize x,y to [0,1]. + 2) Combine them into a single scalar c in [0,1]. + 3) Use matplotlib's colormap to convert c -> (R,G,B). + + You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). + """ + import matplotlib.cm + import matplotlib.colors + + x_norm = x / max(W - 1, 1) + y_norm = y / max(H - 1, 1) + # Simple combination: + c = (x_norm + y_norm) / 2.0 + + cmap = matplotlib.cm.get_cmap(cmap_name) + # cmap(c) -> (r,g,b,a) in [0,1] + rgba = cmap(c) + r, g, b = rgba[0], rgba[1], rgba[2] + return (r, g, b) # in [0,1], RGB order + + +def get_track_colors_by_position( + tracks_b, + vis_mask_b=None, + image_width=None, + image_height=None, + cmap_name="hsv" +): + """ + Given all tracks in one sample (b), compute a (N,3) array of RGB color values + in [0,255]. The color is determined by the (x,y) position in the first + visible frame for each track. + + Args: + tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. + vis_mask_b: (S, N) boolean mask; if None, assume all are visible. + image_width, image_height: used for normalizing (x, y). + cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). + + Returns: + track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. + """ + S, N, _ = tracks_b.shape + track_colors = np.zeros((N, 3), dtype=np.uint8) + + if vis_mask_b is None: + # treat all as visible + vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) + + for i in range(N): + # Find first visible frame for track i + visible_frames = torch.where(vis_mask_b[:, i])[0] + if len(visible_frames) == 0: + # track is never visible; just assign black or something + track_colors[i] = (0, 0, 0) + continue + + first_s = int(visible_frames[0].item()) + # use that frame's (x,y) + x, y = tracks_b[first_s, i].tolist() + + # map (x,y) -> (R,G,B) in [0,1] + r, g, b = color_from_xy( + x, y, + W=image_width, + H=image_height, + cmap_name=cmap_name + ) + # scale to [0,255] + r, g, b = int(r*255), int(g*255), int(b*255) + track_colors[i] = (r, g, b) + + return track_colors + + +def visualize_tracks_on_images( + images, + tracks, + track_vis_mask=None, + out_dir="track_visuals_concat_by_xy", + image_format="CHW", # "CHW" or "HWC" + normalize_mode="[0,1]", + cmap_name="hsv" # e.g. 
"hsv", "rainbow", "jet" +): + """ + Visualizes all frames for each sample (b) in ONE horizontal row, saving + one PNG per sample. Each track's color is determined by its (x,y) position + in the first visible frame (or frame 0 if always visible). + Finally convert the BGR result to RGB before saving. + + Args: + images: torch.Tensor (B, S, 3, H, W) if CHW or (B, S, H, W, 3) if HWC. + tracks: torch.Tensor (B, S, N, 2), last dim = (x, y). + track_vis_mask: torch.Tensor (B, S, N) or None. + out_dir: folder to save visualizations. + image_format: "CHW" or "HWC". + normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 + cmap_name: a matplotlib colormap name for color_from_xy. + + Returns: + None (saves images in out_dir). + """ + import matplotlib + matplotlib.use('Agg') # for non-interactive (optional) + + os.makedirs(out_dir, exist_ok=True) + + B, S = images.shape[0], images.shape[1] + _, _, N, _ = tracks.shape # (B, S, N, 2) + + # Move to CPU + images = images.cpu().clone() + tracks = tracks.cpu().clone() + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.cpu().clone() + + # Infer H, W from images shape + if image_format == "CHW": + # e.g. images[b, s].shape = (3, H, W) + H, W = images.shape[3], images.shape[4] + else: + # e.g. images[b, s].shape = (H, W, 3) + H, W = images.shape[2], images.shape[3] + + for b in range(B): + # Pre-compute the color for each track i based on first visible position + # in sample b: + track_colors_rgb = get_track_colors_by_position( + tracks[b], # shape (S, N, 2) + vis_mask_b=track_vis_mask[b] if track_vis_mask is not None else None, + image_width=W, + image_height=H, + cmap_name=cmap_name + ) + # We'll accumulate each frame’s drawn image in a list + frame_images = [] + + for s in range(S): + # shape => either (3, H, W) or (H, W, 3) + img = images[b, s] + + # Convert to (H, W, 3) + if image_format == "CHW": + img = img.permute(1, 2, 0) # (H, W, 3) + # else "HWC", do nothing + + img = img.numpy().astype(np.float32) + + # Scale to [0,255] if needed + if normalize_mode == "[0,1]": + img = np.clip(img, 0, 1) * 255.0 + elif normalize_mode == "[-1,1]": + img = (img + 1.0) * 0.5 * 255.0 + img = np.clip(img, 0, 255.0) + # else no normalization + + # Convert to uint8 + img = img.astype(np.uint8) + + # For drawing in OpenCV, the image is assumed BGR, + # but *currently* it's in (R,G,B) if your original is truly RGB. + # We'll do the color conversion AFTER drawing so that we can call + # cv2.circle(...) with BGR color. + # That means we need to swap the channels now to get BGR for drawing. + # If your images are actually BGR, you may skip or adapt. + img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + # Draw each visible track + cur_tracks = tracks[b, s] # shape (N, 2) + if track_vis_mask is not None: + valid_indices = torch.where(track_vis_mask[b, s])[0] + else: + valid_indices = range(N) + + cur_tracks_np = cur_tracks.numpy() + for i in valid_indices: + x, y = cur_tracks_np[i] + pt = (int(round(x)), int(round(y))) + + # track_colors_rgb[i] is (R,G,B). 
For OpenCV circle, we need BGR + R, G, B = track_colors_rgb[i] + color_bgr = (int(B), int(G), int(R)) + cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) + + # Convert back to RGB for consistent final saving: + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + + frame_images.append(img_rgb) + + # Concatenate all frames horizontally: (H, S*W, 3) + row_img = np.concatenate(frame_images, axis=1) + + out_path = os.path.join(out_dir, f"tracks_b{b}.png") + cv2.imwrite(out_path, row_img) + print(f"[INFO] Saved color-by-XY track visualization for sample b={b} -> {out_path}") \ No newline at end of file diff --git a/vggt/training/data/worker_fn.py b/vggt/training/data/worker_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..425ce2f49a4be7832178fef70ef2bc887eec1ba8 --- /dev/null +++ b/vggt/training/data/worker_fn.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for distributed training and deterministic seed generation. +This module provides functions for working with PyTorch's distributed +training capabilities and ensuring reproducible data loading. +""" + +import os +import torch +import random +import numpy as np + +import torch.distributed as dist +from functools import partial + + +def is_dist_avail_and_initialized(): + """ + Check if distributed training is available and initialized. + + Returns: + bool: True if distributed training is available and initialized, False otherwise. + """ + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + """ + Get the rank of the current process in distributed training. + + Returns: + int: The rank of the current process, or 0 if distributed training is not initialized. + """ + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_world_size(): + """ + Get the total number of processes in distributed training. + + Returns: + int: The world size, or 1 if distributed training is not initialized. + """ + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def default_worker_init_fn(worker_id, num_workers, epoch, seed=0): + """ + Default function to initialize random seeds for dataloader workers. + + Ensures that each worker across different ranks, epochs, and world sizes + gets a unique random seed for reproducibility. + + Args: + worker_id (int): ID of the dataloader worker. + num_workers (int): Total number of dataloader workers. + epoch (int): Current training epoch. + seed (int, optional): Base seed for randomization. Defaults to 0. 
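+
+    Example (illustrative): for worker_id=3 at epoch=2 with base seed=0, the worker seed
+    is a fixed linear combination of the rank, world size, the RANK environment variable,
+    the epoch and the worker id, so every (rank, worker, epoch) combination reseeds
+    torch, numpy and random with a distinct value.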
+ """ + rank = get_rank() + world_size = get_world_size() + distributed_rank = int(os.environ.get("RANK", None)) + + # Use prime numbers for better distribution + RANK_MULTIPLIER = 1 + WORKER_MULTIPLIER = 1 + WORLD_MULTIPLIER = 1 + EPOCH_MULTIPLIER = 12345 + DISTRIBUTED_RANK_MULTIPLIER = 1042 + + worker_seed = ( + rank * num_workers * RANK_MULTIPLIER + + worker_id * WORKER_MULTIPLIER + + seed + + world_size * WORLD_MULTIPLIER + + epoch * EPOCH_MULTIPLIER + + distributed_rank * DISTRIBUTED_RANK_MULTIPLIER + ) + + print(f"Rank: {rank}, World size: {world_size}, Distributed rank: {distributed_rank}") + print(f"Worker seed: {worker_seed}") + + + torch.random.manual_seed(worker_seed) + np.random.seed(worker_seed) + random.seed(worker_seed) + return + +def get_worker_init_fn(seed, num_workers, epoch, worker_init_fn=None): + """ + Get a worker initialization function for dataloaders. + + Args: + seed (int): Base seed for randomization. + num_workers (int): Number of dataloader workers. + epoch (int): Current training epoch. + worker_init_fn (callable, optional): Custom worker initialization function. + If provided, this will be returned instead of the default one. + + Returns: + callable: A worker initialization function to use with DataLoader. + """ + if worker_init_fn is not None: + return worker_init_fn + + return partial( + default_worker_init_fn, + num_workers=num_workers, + epoch=epoch, + seed=seed, + ) diff --git a/vggt/training/launch.py b/vggt/training/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..5511f510bbc84f2bced9442543b079f10fe58698 --- /dev/null +++ b/vggt/training/launch.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from hydra import initialize, compose +from omegaconf import DictConfig, OmegaConf +from trainer import Trainer + + +def main(): + parser = argparse.ArgumentParser(description="Train model with configurable YAML file") + parser.add_argument( + "--config", + type=str, + default="default", + help="Name of the config file (without .yaml extension, default: default)" + ) + args = parser.parse_args() + + with initialize(version_base=None, config_path="config"): + cfg = compose(config_name=args.config) + + trainer = Trainer(**cfg) + trainer.run() + + +if __name__ == "__main__": + main() + + diff --git a/vggt/training/loss.py b/vggt/training/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f919c7c1aa4ea8a7f6aa1a90fbf510c8c8624210 --- /dev/null +++ b/vggt/training/loss.py @@ -0,0 +1,809 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + +from dataclasses import dataclass +from vggt.utils.pose_enc import extri_intri_to_pose_encoding +from train_utils.general import check_and_fix_inf_nan +from math import ceil, floor + + +@dataclass(eq=False) +class MultitaskLoss(torch.nn.Module): + """ + Multi-task loss module that combines different loss types for VGGT. 
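+    The total objective is the weighted sum of the per-task losses; each task reads its
+    "weight" entry from the corresponding config dict passed to the constructor.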
+ + Supports: + - Camera loss + - Depth loss + - Point loss + - Tracking loss (not cleaned yet, dirty code is at the bottom of this file) + """ + def __init__(self, camera=None, depth=None, point=None, track=None, **kwargs): + super().__init__() + # Loss configuration dictionaries for each task + self.camera = camera + self.depth = depth + self.point = point + self.track = track + + def forward(self, predictions, batch) -> torch.Tensor: + """ + Compute the total multi-task loss. + + Args: + predictions: Dict containing model predictions for different tasks + batch: Dict containing ground truth data and masks + + Returns: + Dict containing individual losses and total objective + """ + total_loss = 0 + loss_dict = {} + + # Camera pose loss - if pose encodings are predicted + if "pose_enc_list" in predictions: + camera_loss_dict = compute_camera_loss(predictions, batch, **self.camera) + camera_loss = camera_loss_dict["loss_camera"] * self.camera["weight"] + total_loss = total_loss + camera_loss + loss_dict.update(camera_loss_dict) + + # Depth estimation loss - if depth maps are predicted + if "depth" in predictions: + depth_loss_dict = compute_depth_loss(predictions, batch, **self.depth) + depth_loss = depth_loss_dict["loss_conf_depth"] + depth_loss_dict["loss_reg_depth"] + depth_loss_dict["loss_grad_depth"] + depth_loss = depth_loss * self.depth["weight"] + total_loss = total_loss + depth_loss + loss_dict.update(depth_loss_dict) + + # 3D point reconstruction loss - if world points are predicted + if "world_points" in predictions: + point_loss_dict = compute_point_loss(predictions, batch, **self.point) + point_loss = point_loss_dict["loss_conf_point"] + point_loss_dict["loss_reg_point"] + point_loss_dict["loss_grad_point"] + point_loss = point_loss * self.point["weight"] + total_loss = total_loss + point_loss + loss_dict.update(point_loss_dict) + + # Tracking loss - not cleaned yet, dirty code is at the bottom of this file + if "track" in predictions: + raise NotImplementedError("Track loss is not cleaned up yet") + + loss_dict["objective"] = total_loss + + return loss_dict + + +def compute_camera_loss( + pred_dict, # predictions dict, contains pose encodings + batch_data, # ground truth and mask batch dict + loss_type="l1", # "l1" or "l2" loss + gamma=0.6, # temporal decay weight for multi-stage training + pose_encoding_type="absT_quaR_FoV", + weight_trans=1.0, # weight for translation loss + weight_rot=1.0, # weight for rotation loss + weight_focal=0.5, # weight for focal length loss + **kwargs +): + # List of predicted pose encodings per stage + pred_pose_encodings = pred_dict['pose_enc_list'] + # Binary mask for valid points per frame (B, N, H, W) + point_masks = batch_data['point_masks'] + # Only consider frames with enough valid points (>100) + valid_frame_mask = point_masks[:, 0].sum(dim=[-1, -2]) > 100 + # Number of prediction stages + n_stages = len(pred_pose_encodings) + + # Get ground truth camera extrinsics and intrinsics + gt_extrinsics = batch_data['extrinsics'] + gt_intrinsics = batch_data['intrinsics'] + image_hw = batch_data['images'].shape[-2:] + + # Encode ground truth pose to match predicted encoding format + gt_pose_encoding = extri_intri_to_pose_encoding( + gt_extrinsics, gt_intrinsics, image_hw, pose_encoding_type=pose_encoding_type + ) + + # Initialize loss accumulators for translation, rotation, focal length + total_loss_T = total_loss_R = total_loss_FL = 0 + + # Compute loss for each prediction stage with temporal weighting + for stage_idx in range(n_stages): + # 
Later stages get higher weight (gamma^0 = 1.0 for final stage) + stage_weight = gamma ** (n_stages - stage_idx - 1) + pred_pose_stage = pred_pose_encodings[stage_idx] + + if valid_frame_mask.sum() == 0: + # If no valid frames, set losses to zero to avoid gradient issues + loss_T_stage = (pred_pose_stage * 0).mean() + loss_R_stage = (pred_pose_stage * 0).mean() + loss_FL_stage = (pred_pose_stage * 0).mean() + else: + # Only consider valid frames for loss computation + loss_T_stage, loss_R_stage, loss_FL_stage = camera_loss_single( + pred_pose_stage[valid_frame_mask].clone(), + gt_pose_encoding[valid_frame_mask].clone(), + loss_type=loss_type + ) + # Accumulate weighted losses across stages + total_loss_T += loss_T_stage * stage_weight + total_loss_R += loss_R_stage * stage_weight + total_loss_FL += loss_FL_stage * stage_weight + + # Average over all stages + avg_loss_T = total_loss_T / n_stages + avg_loss_R = total_loss_R / n_stages + avg_loss_FL = total_loss_FL / n_stages + + # Compute total weighted camera loss + total_camera_loss = ( + avg_loss_T * weight_trans + + avg_loss_R * weight_rot + + avg_loss_FL * weight_focal + ) + + # Return loss dictionary with individual components + return { + "loss_camera": total_camera_loss, + "loss_T": avg_loss_T, + "loss_R": avg_loss_R, + "loss_FL": avg_loss_FL + } + +def camera_loss_single(pred_pose_enc, gt_pose_enc, loss_type="l1"): + """ + Computes translation, rotation, and focal loss for a batch of pose encodings. + + Args: + pred_pose_enc: (N, D) predicted pose encoding + gt_pose_enc: (N, D) ground truth pose encoding + loss_type: "l1" (abs error) or "l2" (euclidean error) + Returns: + loss_T: translation loss (mean) + loss_R: rotation loss (mean) + loss_FL: focal length/intrinsics loss (mean) + + NOTE: The paper uses smooth l1 loss, but we found l1 loss is more stable than smooth l1 and l2 loss. + So here we use l1 loss. + """ + if loss_type == "l1": + # Translation: first 3 dims; Rotation: next 4 (quaternion); Focal/Intrinsics: last dims + loss_T = (pred_pose_enc[..., :3] - gt_pose_enc[..., :3]).abs() + loss_R = (pred_pose_enc[..., 3:7] - gt_pose_enc[..., 3:7]).abs() + loss_FL = (pred_pose_enc[..., 7:] - gt_pose_enc[..., 7:]).abs() + elif loss_type == "l2": + # L2 norm for each component + loss_T = (pred_pose_enc[..., :3] - gt_pose_enc[..., :3]).norm(dim=-1, keepdim=True) + loss_R = (pred_pose_enc[..., 3:7] - gt_pose_enc[..., 3:7]).norm(dim=-1) + loss_FL = (pred_pose_enc[..., 7:] - gt_pose_enc[..., 7:]).norm(dim=-1) + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + # Check/fix numerical issues (nan/inf) for each loss component + loss_T = check_and_fix_inf_nan(loss_T, "loss_T") + loss_R = check_and_fix_inf_nan(loss_R, "loss_R") + loss_FL = check_and_fix_inf_nan(loss_FL, "loss_FL") + + # Clamp outlier translation loss to prevent instability, then average + loss_T = loss_T.clamp(max=100).mean() + loss_R = loss_R.mean() + loss_FL = loss_FL.mean() + + return loss_T, loss_R, loss_FL + + +def compute_point_loss(predictions, batch, gamma=1.0, alpha=0.2, gradient_loss_fn = None, valid_range=-1, **kwargs): + """ + Compute point loss. 
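+    The point map is supervised with the same confidence-weighted regression objective as
+    the depth loss: gamma * ||pred - gt|| * conf - alpha * log(conf), evaluated only on
+    pixels marked valid in `point_masks`, with an optional multi-scale gradient term.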
+ + Args: + predictions: Dict containing 'world_points' and 'world_points_conf' + batch: Dict containing ground truth 'world_points' and 'point_masks' + gamma: Weight for confidence loss + alpha: Weight for confidence regularization + gradient_loss_fn: Type of gradient loss to apply + valid_range: Quantile range for outlier filtering + """ + pred_points = predictions['world_points'] + pred_points_conf = predictions['world_points_conf'] + gt_points = batch['world_points'] + gt_points_mask = batch['point_masks'] + + gt_points = check_and_fix_inf_nan(gt_points, "gt_points") + + if gt_points_mask.sum() < 100: + # If there are less than 100 valid points, skip this batch + dummy_loss = (0.0 * pred_points).mean() + loss_dict = {f"loss_conf_point": dummy_loss, + f"loss_reg_point": dummy_loss, + f"loss_grad_point": dummy_loss,} + return loss_dict + + # Compute confidence-weighted regression loss with optional gradient loss + loss_conf, loss_grad, loss_reg = regression_loss(pred_points, gt_points, gt_points_mask, conf=pred_points_conf, + gradient_loss_fn=gradient_loss_fn, gamma=gamma, alpha=alpha, valid_range=valid_range) + + loss_dict = { + f"loss_conf_point": loss_conf, + f"loss_reg_point": loss_reg, + f"loss_grad_point": loss_grad, + } + + return loss_dict + + +def compute_depth_loss(predictions, batch, gamma=1.0, alpha=0.2, gradient_loss_fn = None, valid_range=-1, **kwargs): + """ + Compute depth loss. + + Args: + predictions: Dict containing 'depth' and 'depth_conf' + batch: Dict containing ground truth 'depths' and 'point_masks' + gamma: Weight for confidence loss + alpha: Weight for confidence regularization + gradient_loss_fn: Type of gradient loss to apply + valid_range: Quantile range for outlier filtering + """ + pred_depth = predictions['depth'] + pred_depth_conf = predictions['depth_conf'] + + gt_depth = batch['depths'] + gt_depth = check_and_fix_inf_nan(gt_depth, "gt_depth") + gt_depth = gt_depth[..., None] # (B, H, W, 1) + gt_depth_mask = batch['point_masks'].clone() # 3D points derived from depth map, so we use the same mask + + if gt_depth_mask.sum() < 100: + # If there are less than 100 valid points, skip this batch + dummy_loss = (0.0 * pred_depth).mean() + loss_dict = {f"loss_conf_depth": dummy_loss, + f"loss_reg_depth": dummy_loss, + f"loss_grad_depth": dummy_loss,} + return loss_dict + + # NOTE: we put conf inside regression_loss so that we can also apply conf loss to the gradient loss in a multi-scale manner + # this is hacky, but very easier to implement + loss_conf, loss_grad, loss_reg = regression_loss(pred_depth, gt_depth, gt_depth_mask, conf=pred_depth_conf, + gradient_loss_fn=gradient_loss_fn, gamma=gamma, alpha=alpha, valid_range=valid_range) + + loss_dict = { + f"loss_conf_depth": loss_conf, + f"loss_reg_depth": loss_reg, + f"loss_grad_depth": loss_grad, + } + + return loss_dict + + +def regression_loss(pred, gt, mask, conf=None, gradient_loss_fn=None, gamma=1.0, alpha=0.2, valid_range=-1): + """ + Core regression loss function with confidence weighting and optional gradient loss. + + Computes: + 1. gamma * ||pred - gt||^2 * conf - alpha * log(conf) + 2. Optional gradient loss + + Args: + pred: (B, S, H, W, C) predicted values + gt: (B, S, H, W, C) ground truth values + mask: (B, S, H, W) valid pixel mask + conf: (B, S, H, W) confidence weights (optional) + gradient_loss_fn: Type of gradient loss ("normal", "grad", etc.) 
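+            Matching is done by substring, so e.g. "normal_conf" (illustrative value)
+            selects the normal-based gradient loss and also feeds the confidence map into it.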
+ gamma: Weight for confidence loss + alpha: Weight for confidence regularization + valid_range: Quantile range for outlier filtering + + Returns: + loss_conf: Confidence-weighted loss + loss_grad: Gradient loss (0 if not specified) + loss_reg: Regular L2 loss + """ + bb, ss, hh, ww, nc = pred.shape + + # Compute L2 distance between predicted and ground truth points + loss_reg = torch.norm(gt[mask] - pred[mask], dim=-1) + loss_reg = check_and_fix_inf_nan(loss_reg, "loss_reg") + + # Confidence-weighted loss: gamma * loss * conf - alpha * log(conf) + # This encourages the model to be confident on easy examples and less confident on hard ones + loss_conf = gamma * loss_reg * conf[mask] - alpha * torch.log(conf[mask]) + loss_conf = check_and_fix_inf_nan(loss_conf, "loss_conf") + + # Initialize gradient loss + loss_grad = 0 + + # Prepare confidence for gradient loss if needed + if "conf" in gradient_loss_fn: + to_feed_conf = conf.reshape(bb*ss, hh, ww) + else: + to_feed_conf = None + + # Compute gradient loss if specified for spatial smoothness + if "normal" in gradient_loss_fn: + # Surface normal-based gradient loss + loss_grad = gradient_loss_multi_scale_wrapper( + pred.reshape(bb*ss, hh, ww, nc), + gt.reshape(bb*ss, hh, ww, nc), + mask.reshape(bb*ss, hh, ww), + gradient_loss_fn=normal_loss, + scales=3, + conf=to_feed_conf, + ) + elif "grad" in gradient_loss_fn: + # Standard gradient-based loss + loss_grad = gradient_loss_multi_scale_wrapper( + pred.reshape(bb*ss, hh, ww, nc), + gt.reshape(bb*ss, hh, ww, nc), + mask.reshape(bb*ss, hh, ww), + gradient_loss_fn=gradient_loss, + conf=to_feed_conf, + ) + + # Process confidence-weighted loss + if loss_conf.numel() > 0: + # Filter out outliers using quantile-based thresholding + if valid_range>0: + loss_conf = filter_by_quantile(loss_conf, valid_range) + + loss_conf = check_and_fix_inf_nan(loss_conf, f"loss_conf_depth") + loss_conf = loss_conf.mean() + else: + loss_conf = (0.0 * pred).mean() + + # Process regular regression loss + if loss_reg.numel() > 0: + # Filter out outliers using quantile-based thresholding + if valid_range>0: + loss_reg = filter_by_quantile(loss_reg, valid_range) + + loss_reg = check_and_fix_inf_nan(loss_reg, f"loss_reg_depth") + loss_reg = loss_reg.mean() + else: + loss_reg = (0.0 * pred).mean() + + return loss_conf, loss_grad, loss_reg + + +def gradient_loss_multi_scale_wrapper(prediction, target, mask, scales=4, gradient_loss_fn = None, conf=None): + """ + Multi-scale gradient loss wrapper. Applies gradient loss at multiple scales by subsampling the input. + This helps capture both fine and coarse spatial structures. + + Args: + prediction: (B, H, W, C) predicted values + target: (B, H, W, C) ground truth values + mask: (B, H, W) valid pixel mask + scales: Number of scales to use + gradient_loss_fn: Gradient loss function to apply + conf: (B, H, W) confidence weights (optional) + """ + total = 0 + for scale in range(scales): + step = pow(2, scale) # Subsample by 2^scale + + total += gradient_loss_fn( + prediction[:, ::step, ::step], + target[:, ::step, ::step], + mask[:, ::step, ::step], + conf=conf[:, ::step, ::step] if conf is not None else None + ) + + total = total / scales + return total + + +def normal_loss(prediction, target, mask, cos_eps=1e-8, conf=None, gamma=1.0, alpha=0.2): + """ + Surface normal-based loss for geometric consistency. + + Computes surface normals from 3D point maps using cross products of neighboring points, + then measures the angle between predicted and ground truth normals. 
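+    The loss is 1 - cos(theta) between corresponding normals, optionally confidence
+    weighted as gamma * loss * conf - alpha * log(conf); 0 is returned when fewer than
+    10 valid normal pairs are available.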
+ + Args: + prediction: (B, H, W, 3) predicted 3D coordinates/points + target: (B, H, W, 3) ground-truth 3D coordinates/points + mask: (B, H, W) valid pixel mask + cos_eps: Epsilon for numerical stability in cosine computation + conf: (B, H, W) confidence weights (optional) + gamma: Weight for confidence loss + alpha: Weight for confidence regularization + """ + # Convert point maps to surface normals using cross products + pred_normals, pred_valids = point_map_to_normal(prediction, mask, eps=cos_eps) + gt_normals, gt_valids = point_map_to_normal(target, mask, eps=cos_eps) + + # Only consider regions where both predicted and GT normals are valid + all_valid = pred_valids & gt_valids # shape: (4, B, H, W) + + # Early return if not enough valid points + divisor = torch.sum(all_valid) + if divisor < 10: + return 0 + + # Extract valid normals + pred_normals = pred_normals[all_valid].clone() + gt_normals = gt_normals[all_valid].clone() + + # Compute cosine similarity between corresponding normals + dot = torch.sum(pred_normals * gt_normals, dim=-1) + + # Clamp dot product to [-1, 1] for numerical stability + dot = torch.clamp(dot, -1 + cos_eps, 1 - cos_eps) + + # Compute loss as 1 - cos(theta), instead of arccos(dot) for numerical stability + loss = 1 - dot + + # Return mean loss if we have enough valid points + if loss.numel() < 10: + return 0 + else: + loss = check_and_fix_inf_nan(loss, "normal_loss") + + if conf is not None: + # Apply confidence weighting + conf = conf[None, ...].expand(4, -1, -1, -1) + conf = conf[all_valid].clone() + + loss = gamma * loss * conf - alpha * torch.log(conf) + return loss.mean() + else: + return loss.mean() + + +def gradient_loss(prediction, target, mask, conf=None, gamma=1.0, alpha=0.2): + """ + Gradient-based loss. Computes the L1 difference between adjacent pixels in x and y directions. 
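+    The error map (prediction - target) is masked first, finite differences are taken
+    along x and y, clamped to 100, optionally confidence weighted, and the result is
+    normalized by the number of valid mask entries (0 is returned if none are valid).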
+ + Args: + prediction: (B, H, W, C) predicted values + target: (B, H, W, C) ground truth values + mask: (B, H, W) valid pixel mask + conf: (B, H, W) confidence weights (optional) + gamma: Weight for confidence loss + alpha: Weight for confidence regularization + """ + # Expand mask to match prediction channels + mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1]) + M = torch.sum(mask, (1, 2, 3)) + + # Compute difference between prediction and target + diff = prediction - target + diff = torch.mul(mask, diff) + + # Compute gradients in x direction (horizontal) + grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) + mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) + grad_x = torch.mul(mask_x, grad_x) + + # Compute gradients in y direction (vertical) + grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) + mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) + grad_y = torch.mul(mask_y, grad_y) + + # Clamp gradients to prevent outliers + grad_x = grad_x.clamp(max=100) + grad_y = grad_y.clamp(max=100) + + # Apply confidence weighting if provided + if conf is not None: + conf = conf[..., None].expand(-1, -1, -1, prediction.shape[-1]) + conf_x = conf[:, :, 1:] + conf_y = conf[:, 1:, :] + + grad_x = gamma * grad_x * conf_x - alpha * torch.log(conf_x) + grad_y = gamma * grad_y * conf_y - alpha * torch.log(conf_y) + + # Sum gradients and normalize by number of valid pixels + grad_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3)) + divisor = torch.sum(M) + + if divisor == 0: + return 0 + else: + grad_loss = torch.sum(grad_loss) / divisor + + return grad_loss + + +def point_map_to_normal(point_map, mask, eps=1e-6): + """ + Convert 3D point map to surface normal vectors using cross products. + + Computes normals by taking cross products of neighboring point differences. + Uses 4 different cross-product directions for robustness. 
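+    The four directions are (up x left), (left x down), (down x right) and (right x up)
+    around each pixel; the computation runs with autocast disabled and the resulting
+    normals are L2-normalized.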
+ + Args: + point_map: (B, H, W, 3) 3D points laid out in a 2D grid + mask: (B, H, W) valid pixels (bool) + eps: Epsilon for numerical stability in normalization + + Returns: + normals: (4, B, H, W, 3) normal vectors for each of the 4 cross-product directions + valids: (4, B, H, W) corresponding valid masks + """ + with torch.cuda.amp.autocast(enabled=False): + # Pad inputs to avoid boundary issues + padded_mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) + pts = F.pad(point_map.permute(0, 3, 1, 2), (1,1,1,1), mode='constant', value=0).permute(0, 2, 3, 1) + + # Get neighboring points for each pixel + center = pts[:, 1:-1, 1:-1, :] # B,H,W,3 + up = pts[:, :-2, 1:-1, :] + left = pts[:, 1:-1, :-2 , :] + down = pts[:, 2:, 1:-1, :] + right = pts[:, 1:-1, 2:, :] + + # Compute direction vectors from center to neighbors + up_dir = up - center + left_dir = left - center + down_dir = down - center + right_dir = right - center + + # Compute four cross products for different normal directions + n1 = torch.cross(up_dir, left_dir, dim=-1) # up x left + n2 = torch.cross(left_dir, down_dir, dim=-1) # left x down + n3 = torch.cross(down_dir, right_dir, dim=-1) # down x right + n4 = torch.cross(right_dir,up_dir, dim=-1) # right x up + + # Validity masks - require both direction pixels to be valid + v1 = padded_mask[:, :-2, 1:-1] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 1:-1, :-2] + v2 = padded_mask[:, 1:-1, :-2 ] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 2:, 1:-1] + v3 = padded_mask[:, 2:, 1:-1] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 1:-1, 2:] + v4 = padded_mask[:, 1:-1, 2: ] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, :-2, 1:-1] + + # Stack normals and validity masks + normals = torch.stack([n1, n2, n3, n4], dim=0) # shape [4, B, H, W, 3] + valids = torch.stack([v1, v2, v3, v4], dim=0) # shape [4, B, H, W] + + # Normalize normal vectors + normals = F.normalize(normals, p=2, dim=-1, eps=eps) + + return normals, valids + + +def filter_by_quantile(loss_tensor, valid_range, min_elements=1000, hard_max=100): + """ + Filter loss tensor by keeping only values below a certain quantile threshold. + + This helps remove outliers that could destabilize training. 
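+    For example (illustrative), valid_range=0.98 drops loss values above the 0.98
+    quantile, with the threshold itself capped at hard_max. Tensors with at most
+    min_elements entries are returned unchanged, and tensors larger than 1e8 elements
+    are randomly subsampled to 1M entries before filtering.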
+ + Args: + loss_tensor: Tensor containing loss values + valid_range: Float between 0 and 1 indicating the quantile threshold + min_elements: Minimum number of elements required to apply filtering + hard_max: Maximum allowed value for any individual loss + + Returns: + Filtered and clamped loss tensor + """ + if loss_tensor.numel() <= min_elements: + # Too few elements, just return as-is + return loss_tensor + + # Randomly sample if tensor is too large to avoid memory issues + if loss_tensor.numel() > 100000000: + # Flatten and randomly select 1M elements + indices = torch.randperm(loss_tensor.numel(), device=loss_tensor.device)[:1_000_000] + loss_tensor = loss_tensor.view(-1)[indices] + + # First clamp individual values to prevent extreme outliers + loss_tensor = loss_tensor.clamp(max=hard_max) + + # Compute quantile threshold + quantile_thresh = torch_quantile(loss_tensor.detach(), valid_range) + quantile_thresh = min(quantile_thresh, hard_max) + + # Apply quantile filtering if enough elements remain + quantile_mask = loss_tensor < quantile_thresh + if quantile_mask.sum() > min_elements: + return loss_tensor[quantile_mask] + return loss_tensor + + +def torch_quantile( + input, + q, + dim = None, + keepdim: bool = False, + *, + interpolation: str = "nearest", + out: torch.Tensor = None, +) -> torch.Tensor: + """Better torch.quantile for one SCALAR quantile. + + Using torch.kthvalue. Better than torch.quantile because: + - No 2**24 input size limit (pytorch/issues/67592), + - Much faster, at least on big input sizes. + + Arguments: + input (torch.Tensor): See torch.quantile. + q (float): See torch.quantile. Supports only scalar input + currently. + dim (int | None): See torch.quantile. + keepdim (bool): See torch.quantile. Supports only False + currently. + interpolation: {"nearest", "lower", "higher"} + See torch.quantile. + out (torch.Tensor | None): See torch.quantile. Supports only + None currently. + """ + # https://github.com/pytorch/pytorch/issues/64947 + # Sanitization: q + try: + q = float(q) + assert 0 <= q <= 1 + except Exception: + raise ValueError(f"Only scalar input 0<=q<=1 is currently supported (got {q})!") + + # Handle dim=None case + if dim_was_none := dim is None: + dim = 0 + input = input.reshape((-1,) + (1,) * (input.ndim - 1)) + + # Set interpolation method + if interpolation == "nearest": + inter = round + elif interpolation == "lower": + inter = floor + elif interpolation == "higher": + inter = ceil + else: + raise ValueError( + "Supported interpolations currently are {'nearest', 'lower', 'higher'} " + f"(got '{interpolation}')!" 
+ ) + + # Validate out parameter + if out is not None: + raise ValueError(f"Only None value is currently supported for out (got {out})!") + + # Compute k-th value + k = inter(q * (input.shape[dim] - 1)) + 1 + out = torch.kthvalue(input, k, dim, keepdim=True, out=out)[0] + + # Handle keepdim and dim=None cases + if keepdim: + return out + if dim_was_none: + return out.squeeze() + else: + return out.squeeze(dim) + + return out + + +######################################################################################## +######################################################################################## + +# Dirty code for tracking loss: + +######################################################################################## +######################################################################################## + +''' +def _compute_losses(self, coord_preds, vis_scores, conf_scores, batch): + """Compute tracking losses using sequence_loss""" + gt_tracks = batch["tracks"] # B, S, N, 2 + gt_track_vis_mask = batch["track_vis_mask"] # B, S, N + + # if self.training and hasattr(self, "train_query_points"): + train_query_points = coord_preds[-1].shape[2] + gt_tracks = gt_tracks[:, :, :train_query_points] + gt_tracks = check_and_fix_inf_nan(gt_tracks, "gt_tracks", hard_max=None) + + gt_track_vis_mask = gt_track_vis_mask[:, :, :train_query_points] + + # Create validity mask that filters out tracks not visible in first frame + valids = torch.ones_like(gt_track_vis_mask) + mask = gt_track_vis_mask[:, 0, :] == True + valids = valids * mask.unsqueeze(1) + + + + if not valids.any(): + print("No valid tracks found in first frame") + print("seq_name: ", batch["seq_name"]) + print("ids: ", batch["ids"]) + print("time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) + + dummy_coord = coord_preds[0].mean() * 0 # keeps graph & grads + dummy_vis = vis_scores.mean() * 0 + if conf_scores is not None: + dummy_conf = conf_scores.mean() * 0 + else: + dummy_conf = 0 + return dummy_coord, dummy_vis, dummy_conf # three scalar zeros + + + # Compute tracking loss using sequence_loss + track_loss = sequence_loss( + flow_preds=coord_preds, + flow_gt=gt_tracks, + vis=gt_track_vis_mask, + valids=valids, + **self.loss_kwargs + ) + + vis_loss = F.binary_cross_entropy_with_logits(vis_scores[valids], gt_track_vis_mask[valids].float()) + + vis_loss = check_and_fix_inf_nan(vis_loss, "vis_loss", hard_max=None) + + + # within 3 pixels + if conf_scores is not None: + gt_conf_mask = (gt_tracks - coord_preds[-1]).norm(dim=-1) < 3 + conf_loss = F.binary_cross_entropy_with_logits(conf_scores[valids], gt_conf_mask[valids].float()) + conf_loss = check_and_fix_inf_nan(conf_loss, "conf_loss", hard_max=None) + else: + conf_loss = 0 + + return track_loss, vis_loss, conf_loss + + + +def reduce_masked_mean(x, mask, dim=None, keepdim=False): + for a, b in zip(x.size(), mask.size()): + assert a == b + prod = x * mask + + if dim is None: + numer = torch.sum(prod) + denom = torch.sum(mask) + else: + numer = torch.sum(prod, dim=dim, keepdim=keepdim) + denom = torch.sum(mask, dim=dim, keepdim=keepdim) + + mean = numer / denom.clamp(min=1) + mean = torch.where(denom > 0, + mean, + torch.zeros_like(mean)) + return mean + + +def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8, vis_aware=False, huber=False, delta=10, vis_aware_w=0.1, **kwargs): + """Loss function defined over sequence of flow predictions""" + B, S, N, D = flow_gt.shape + assert D == 2 + B, S1, N = vis.shape + B, S2, N = valids.shape + assert S == S1 + assert S 
== S2 + n_predictions = len(flow_preds) + flow_loss = 0.0 + + for i in range(n_predictions): + i_weight = gamma ** (n_predictions - i - 1) + flow_pred = flow_preds[i] + + i_loss = (flow_pred - flow_gt).abs() # B, S, N, 2 + i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_{i}", hard_max=None) + + i_loss = torch.mean(i_loss, dim=3) # B, S, N + + # Combine valids and vis for per-frame valid masking. + combined_mask = torch.logical_and(valids, vis) + + num_valid_points = combined_mask.sum() + + if vis_aware: + combined_mask = combined_mask.float() * (1.0 + vis_aware_w) # Add, don't add to the mask itself. + flow_loss += i_weight * reduce_masked_mean(i_loss, combined_mask) + else: + if num_valid_points > 2: + i_loss = i_loss[combined_mask] + flow_loss += i_weight * i_loss.mean() + else: + i_loss = check_and_fix_inf_nan(i_loss, f"i_loss_iter_safe_check_{i}", hard_max=None) + flow_loss += 0 * i_loss.mean() + + # Avoid division by zero if n_predictions is 0 (though it shouldn't be). + if n_predictions > 0: + flow_loss = flow_loss / n_predictions + + return flow_loss +''' + + diff --git a/vggt/training/train_utils/__init__.py b/vggt/training/train_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggt/training/train_utils/checkpoint.py b/vggt/training/train_utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..590da0d76e65e6c3e5bdeb6993ba3c99375dba92 --- /dev/null +++ b/vggt/training/train_utils/checkpoint.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from typing import ( + Any, + Dict, + List, +) + +import torch +import torch.nn as nn +import os +from iopath.common.file_io import g_pathmgr +from wcmatch import fnmatch + + + + + +# ------------------------------------------------------------ +# Glob‑matching flags (behave like the Unix shell) +# ------------------------------------------------------------ +GLOB_FLAGS = ( + fnmatch.CASE # case‑sensitive + | fnmatch.DOTMATCH # '*' also matches '.' + | fnmatch.EXTMATCH # extended patterns like *(foo|bar) + | fnmatch.SPLIT # "pat1|pat2" works out‑of‑the‑box +) + + + + +class DDPCheckpointSaver: + def __init__( + self, + checkpoint_folder: str, + checkpoint_names: List[str], + rank: int, + epoch: int, + ): + super().__init__() + self.checkpoint_folder = checkpoint_folder + self.checkpoint_names = checkpoint_names + self.worker_id = rank + self.epoch = epoch + + def save_checkpoint( + self, + model: nn.Module, + **kwargs: Any, + ) -> None: + checkpoint = dict(**kwargs) + checkpoint["model"] = model.state_dict() + + if self.worker_id == 0: + for ckpt_name in self.checkpoint_names: + checkpoint_path = os.path.join( + self.checkpoint_folder, f"{ckpt_name}.pt" + ) + logging.info( + f"Saving checkpoint at epoch {self.epoch} to {checkpoint_path}" + ) + robust_torch_save(checkpoint, checkpoint_path) + + + +def robust_torch_save(checkpoint: Dict[str, Any], checkpoint_path: str) -> None: + """ + A more robust version of torch.save that works better with preemptions + and corruptions if a job is preempted during save. 
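+    Any existing checkpoint is first renamed to "<checkpoint_path>.bak", the new
+    checkpoint is then written, and the backup is removed only after the write
+    completes, so an intact copy exists at every point in time.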
+ """ + # Move the existing checkpoint to a backup location + backup_checkpoint_path = checkpoint_path + ".bak" + backup_checkpoint_path_saved = False + if g_pathmgr.exists(checkpoint_path): + assert not g_pathmgr.exists( + backup_checkpoint_path + ), f"this should not exist... {backup_checkpoint_path}" + g_pathmgr.mv(checkpoint_path, backup_checkpoint_path) + backup_checkpoint_path_saved = True + # Save the checkpoint + with g_pathmgr.open(checkpoint_path, "wb") as f: + torch.save(checkpoint, f) + # Remove the backup checkpoint + if backup_checkpoint_path_saved: + g_pathmgr.rm(backup_checkpoint_path) \ No newline at end of file diff --git a/vggt/training/train_utils/distributed.py b/vggt/training/train_utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..af61e2697324589e535434fad11ecca0035c769e --- /dev/null +++ b/vggt/training/train_utils/distributed.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import time +import torch + +def get_machine_local_and_dist_rank(): + """ + Get the distributed and local rank of the current gpu. + """ + local_rank = int(os.environ.get("LOCAL_RANK", None)) + distributed_rank = int(os.environ.get("RANK", None)) + assert ( + local_rank is not None and distributed_rank is not None + ), "Please the set the RANK and LOCAL_RANK environment variables." + return local_rank, distributed_rank diff --git a/vggt/training/train_utils/freeze.py b/vggt/training/train_utils/freeze.py new file mode 100644 index 0000000000000000000000000000000000000000..563d434879cf868cde1afd3bf40f10ea0f6e09f7 --- /dev/null +++ b/vggt/training/train_utils/freeze.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from wcmatch import fnmatch +from functools import wraps +from typing import List + +import torch.nn as nn + +# ------------------------------------------------------------ +# Glob‑matching flags (behave like the Unix shell) +# ------------------------------------------------------------ +GLOB_FLAGS = ( + fnmatch.CASE # case‑sensitive + | fnmatch.DOTMATCH # '*' also matches '.' + | fnmatch.EXTMATCH # extended patterns like *(foo|bar) + | fnmatch.SPLIT # "pat1|pat2" works out‑of‑the‑box +) + + +def freeze_modules(model: nn.Module, patterns: List[str], recursive: bool = True) -> nn.Module: + """Freeze (stop training) parts of *model* whose *name* matches *patterns*. + + Parameters + ---------- + model : nn.Module + The complete model you are working with. + patterns : list[str] + Glob patterns to match sub‑module names. Example: ``["encoder.*", "cls_head"]`` + recursive : bool, default = True + • ``True`` → also freeze every child of a matched module. + • ``False`` → freeze only the matched module itself. + + Returns + ------- + nn.Module + The same model object, now with some parts frozen. + + Example + ------- + >>> freeze_modules(model, ["encoder.*", "decoder.layer1"], recursive=True) + """ + matched: set[str] = set() + + for name, mod in model.named_modules(): + # does *name* match ANY user pattern? 
+ if any(fnmatch.fnmatch(name, p, flags=GLOB_FLAGS) for p in patterns): + matched.add(name) + _freeze(mod, recursive) + + _check_every_pattern_used(matched, patterns) + return model + + +# ------------------------------------------------------------ +# helpers +# ------------------------------------------------------------ + +def _freeze(mod: nn.Module, recursive: bool) -> None: + """Put *mod* in eval mode and lock its parameters.""" + + if recursive: + mod.eval() # affects the whole subtree + else: + mod.training = False # only this exact module + + original_train = mod.train + + @wraps(original_train) + def locked_train(mode: bool = True): + if recursive: + return original_train(False) # ignore user's *mode* + out = original_train(mode) # children follow user's choice + out.training = False # but this module stays frozen + return out + + mod.train = locked_train # type: ignore[attr-defined] + + param_iter = ( + mod.parameters() # default recurse=True + if recursive + else mod.parameters(recurse=False) + ) + for p in param_iter: + p.requires_grad = False + + +def _check_every_pattern_used(matched_names: set[str], patterns: List[str]): + unused = [p for p in patterns if not any(fnmatch.fnmatch(n, p, flags=GLOB_FLAGS) + for n in matched_names)] + if unused: + raise ValueError(f"These patterns matched nothing: {unused}") diff --git a/vggt/training/train_utils/general.py b/vggt/training/train_utils/general.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0c88a05c9d3e7e7c0f09c1ab8c843b7285b29f --- /dev/null +++ b/vggt/training/train_utils/general.py @@ -0,0 +1,369 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import os +import math +import random +import numpy as np +from typing import Union, Optional +import logging +from iopath.common.file_io import g_pathmgr +import torch.distributed as dist +from pathlib import Path +from typing import Dict, Iterable, List + + + +from collections import defaultdict +from dataclasses import fields, is_dataclass +from typing import Any, Mapping, Protocol, runtime_checkable + + + + +def check_and_fix_inf_nan(input_tensor, loss_name="default", hard_max=100): + """ + Checks if 'input_tensor' contains inf or nan values and clamps extreme values. + + Args: + input_tensor (torch.Tensor): The loss tensor to check and fix. + loss_name (str): Name of the loss (for diagnostic prints). + hard_max (float, optional): Maximum absolute value allowed. Values outside + [-hard_max, hard_max] will be clamped. If None, + no clamping is performed. Defaults to 100. + """ + if input_tensor is None: + return input_tensor + + # Check for inf/nan values + has_inf_nan = torch.isnan(input_tensor).any() or torch.isinf(input_tensor).any() + if has_inf_nan: + logging.warning(f"Tensor {loss_name} contains inf or nan values. 
Replacing with zeros.") + input_tensor = torch.where( + torch.isnan(input_tensor) | torch.isinf(input_tensor), + torch.zeros_like(input_tensor), + input_tensor + ) + + # Apply hard clamping if specified + if hard_max is not None: + input_tensor = torch.clamp(input_tensor, min=-hard_max, max=hard_max) + + return input_tensor + + +def get_resume_checkpoint(checkpoint_save_dir): + if not g_pathmgr.isdir(checkpoint_save_dir): + return None + ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt") + if not g_pathmgr.isfile(ckpt_file): + return None + + return ckpt_file + +class DurationMeter: + def __init__(self, name, device, fmt=":f"): + self.name = name + self.device = device + self.fmt = fmt + self.val = 0 + + def reset(self): + self.val = 0 + + def update(self, val): + self.val = val + + def add(self, val): + self.val += val + + def __str__(self): + return f"{self.name}: {human_readable_time(self.val)}" + + +def human_readable_time(time_seconds): + time = int(time_seconds) + minutes, seconds = divmod(time, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + return f"{days:02}d {hours:02}h {minutes:02}m" + + + +class ProgressMeter: + def __init__(self, num_batches, meters, real_meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.real_meters = real_meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + entries += [ + " | ".join( + [ + f"{os.path.join(name, subname)}: {val:.4f}" + for subname, val in meter.compute().items() + ] + ) + for name, meter in self.real_meters.items() + ] + logging.info(" | ".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + + +@runtime_checkable +class _CopyableData(Protocol): + def to(self, device: torch.device, *args: Any, **kwargs: Any): + """Copy data to the specified device""" + ... + + +def _is_named_tuple(x) -> bool: + return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields") + + +def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any): + """Function that recursively copies data to a torch.device. 
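+    Namedtuples, lists/tuples, (default)dicts, dataclasses and any object exposing a
+    .to() method are handled recursively; anything else is returned unchanged. Typical
+    use (illustrative): batch = copy_data_to_device(batch, torch.device("cuda")).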
+ + Args: + data: The data to copy to device + device: The device to which the data should be copied + args: positional arguments that will be passed to the `to` call + kwargs: keyword arguments that will be passed to the `to` call + + Returns: + The data on the correct device + """ + + if _is_named_tuple(data): + return type(data)( + **copy_data_to_device(data._asdict(), device, *args, **kwargs) + ) + elif isinstance(data, (list, tuple)): + return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data) + elif isinstance(data, defaultdict): + return type(data)( + data.default_factory, + { + k: copy_data_to_device(v, device, *args, **kwargs) + for k, v in data.items() + }, + ) + elif isinstance(data, Mapping) and not is_dataclass(data): # handing FrameData-like things + return type(data)( + { + k: copy_data_to_device(v, device, *args, **kwargs) + for k, v in data.items() + } + ) + elif is_dataclass(data) and not isinstance(data, type): + new_data_class = type(data)( + **{ + field.name: copy_data_to_device( + getattr(data, field.name), device, *args, **kwargs + ) + for field in fields(data) + if field.init + } + ) + for field in fields(data): + if not field.init: + setattr( + new_data_class, + field.name, + copy_data_to_device( + getattr(data, field.name), device, *args, **kwargs + ), + ) + return new_data_class + elif isinstance(data, _CopyableData): + return data.to(device, *args, **kwargs) + return data + + + +def safe_makedirs(path: str): + if not path: + logging.warning("safe_makedirs called with an empty path. No operation performed.") + return False + + try: + os.makedirs(path, exist_ok=True) + return True + except OSError as e: + logging.error(f"Failed to create directory '{path}'. Reason: {e}") + raise + except Exception as e: + # Catch any other unexpected errors. + logging.error(f"An unexpected error occurred while creating directory '{path}'. Reason: {e}") + raise + + + +def set_seeds(seed_value, max_epochs, dist_rank): + """ + Set the python random, numpy and torch seed for each gpu. Also set the CUDA + seeds if the CUDA is available. This ensures deterministic nature of the training. + """ + seed_value = (seed_value + dist_rank) * max_epochs + logging.info(f"GPU SEED: {seed_value}") + random.seed(seed_value) + np.random.seed(seed_value) + torch.manual_seed(seed_value) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed_value) + torch.cuda.manual_seed_all(seed_value) # for multi-GPU + + + + +def log_env_variables(): + env_keys = sorted(list(os.environ.keys())) + st = "" + for k in env_keys: + v = os.environ[k] + st += f"{k}={v}\n" + logging.info("Logging ENV_VARIABLES") + logging.info(st) + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + + +class AverageMeter: + """Computes and stores the average and current value. + Args: + name (str): Name of the metric being tracked + device (torch.device, optional): Device for tensor operations. Defaults to None. + fmt (str): Format string for displaying values. 
Defaults to ":f" + """ + + def __init__(self, name: str, device: Optional[torch.device] = None, fmt: str = ":f"): + self.name = name + self.fmt = fmt + self.device = device + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self._allow_updates = True + + def update(self, val, n=1): + if n <= 0: + raise ValueError(f"n must be positive, got {n}") + + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count if self.count > 0 else 0.0 + + def __str__(self) -> str: + """String representation showing current and average values.""" + fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + @property + def value(self) -> float: + """Get the current value.""" + return self.val + + @property + def average(self) -> float: + """Get the running average.""" + return self.avg + +################# + + +_UNITS = ('', ' K', ' M', ' B', ' T') # U+202F = thin-space for nicer look + +def pretty_int(n: int) -> str: + """Abbreviate a non-negative integer (0 → 0, 12_345 → '12.3 K').""" + assert n >= 0, 'pretty_int() expects a non-negative int' + if n < 1_000: + return f'{n:,}' + exp = int(math.log10(n) // 3) # group of 3 digits + exp = min(exp, len(_UNITS) - 1) # cap at trillions + value = n / 10 ** (3 * exp) + return f'{value:.1f}'.rstrip('0').rstrip('.') + _UNITS[exp] + + +def model_summary(model: torch.nn.Module, + *, + log_file = None, + prefix: str = '') -> None: + """ + Print / save a compact parameter summary. + + Args + ---- + model : The PyTorch nn.Module to inspect. + log_file : Optional path – if given, the full `str(model)` and per-parameter + lists are written there (three separate *.txt files). + prefix : Optional string printed at the beginning of every log line + (handy when several models share the same stdout). 
+ """ + if get_rank(): # only rank-0 prints + return + + # --- counts ------------------------------------------------------------- + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + total = sum(p.numel() for p in model.parameters()) + frozen = total - trainable + + print(prefix + '='*60) + print(prefix + f'Model type : {model.__class__.__name__}') + print(prefix + f'Total : {pretty_int(total)} parameters') + print(prefix + f' trainable: {pretty_int(trainable)}') + print(prefix + f' frozen : {pretty_int(frozen)}') + print(prefix + '='*60) + + # --- optional file dump ------------------------------------------------- + if log_file is None: + return + + log_file = Path(log_file) + log_file.write_text(str(model)) # full architecture + + # two extra detailed lists + def _dump(names: Iterable[str], fname: str): + """Write a formatted per-parameter list to *log_file.with_name(fname)*.""" + with open(log_file.with_name(fname), 'w') as f: + for n in names: + p = dict(model.named_parameters())[n] + shape = str(tuple(p.shape)) + f.write(f'{n:<60s} {shape:<20} {p.numel()}\n') + + named = dict(model.named_parameters()) + _dump([n for n,p in named.items() if p.requires_grad], 'trainable.txt') + _dump([n for n,p in named.items() if not p.requires_grad], 'frozen.txt') + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + diff --git a/vggt/training/train_utils/gradient_clip.py b/vggt/training/train_utils/gradient_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb9dffb7e7993bb9c85f56a9091eeb7bbac776a --- /dev/null +++ b/vggt/training/train_utils/gradient_clip.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from typing import Union, Optional + + +class GradientClipper: + """ + Gradient clipping utils that works for both FSDP and DDP with support for different + clipping configurations for different parts of the model. + """ + def __init__(self, configs, *args, **kwargs): + """ + Args: + configs: List of dictionaries, each containing: + - module_name: str or list of str, module names to apply clipping to + - max_norm: float, maximum norm for gradient clipping + - norm_type: int, type of norm (default: 2) + """ + self.configs = [] + self.params_to_clip_by_config = None + self.is_initialized = False + + for config in configs: + module_names = config['module_name'] + if isinstance(module_names, str): + module_names = [module_names] + + self.configs.append({ + 'module_names': module_names, + 'max_norm': float(config['max_norm']) if config['max_norm'] is not None else None, + 'norm_type': config.get('norm_type', 2) + }) + + def setup_clipping(self, model: nn.Module) -> None: + """ + Set up gradient clipping by finding all parameters that should be clipped + based on module names and validating that all parameters are covered. + + This should be called once at the beginning of training. 
+ + Args: + model: The model to set up gradient clipping for + """ + # First, collect all parameters that should be clipped based on module names + params_to_clip_by_config = [] + all_clipped_params = set() + + for config in self.configs: + current_config_params = [] + for name, param in model.named_parameters(): + if param.requires_grad: + for module_name in config['module_names']: + if module_name in name: + current_config_params.append(param) + all_clipped_params.add(param) + break + params_to_clip_by_config.append((config, current_config_params)) + + # Check for remaining parameters + remaining_params = [] + for name, param in model.named_parameters(): + if param.requires_grad and param not in all_clipped_params: + remaining_params.append(param) + + if len(remaining_params) > 0: + print(f"Found {len(remaining_params)} parameters that won't be clipped") + print(remaining_params) + raise ValueError("Some parameters are not configured for gradient clipping") + + # Store the computed parameters + self.params_to_clip_by_config = params_to_clip_by_config + self.is_initialized = True + + def __call__(self, model: nn.Module) -> Optional[torch.Tensor]: + """ + Perform gradient clipping using the pre-computed parameter groups. + + Args: + model: The model (not used, kept for backward compatibility) + + Returns: + Dictionary of gradient norms for each configuration + """ + if not self.is_initialized: + raise RuntimeError("GradientClipper must be initialized with setup_clipping() before use") + + grad_norms = {} + for config, params_to_clip in self.params_to_clip_by_config: + if not params_to_clip or config['max_norm'] is None: + continue + + grad_norm = nn.utils.clip_grad_norm_( + params_to_clip, + max_norm=config['max_norm'], + norm_type=config['norm_type'] + ) + + if grad_norm is None: + continue + + grad_norms[",".join(config['module_names'])] = grad_norm.item() + + return grad_norms diff --git a/vggt/training/train_utils/logging.py b/vggt/training/train_utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..a268b17751e650572e10161b988a7dfd5f4e416e --- /dev/null +++ b/vggt/training/train_utils/logging.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +import copy +import sys +import atexit + +import functools +from .general import safe_makedirs +from iopath.common.file_io import g_pathmgr + + +# cache the opened file object, so that different calls +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + log_buffer_kb = 1 * 1024 # 1KB + io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb) + atexit.register(io.close) + return io + + + +def setup_logging( + name, + output_dir=None, + rank=0, + log_level_primary="INFO", + log_level_secondary="ERROR", + all_ranks: bool = False, +): + """ + Setup various logging streams: stdout and file handlers. + For file handlers, we only setup for the master gpu. 
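+    Rank 0 logs to "<output_dir>/log.txt"; with all_ranks=True every other rank writes
+    its own "log_<rank>.txt". The console handler uses log_level_primary on rank 0 and
+    log_level_secondary on all other ranks.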
+ """ + global LOGGING_STATE + LOGGING_STATE = copy.deepcopy(locals()) + + # get the filename if we want to log to the file as well + log_filename = None + if output_dir: + safe_makedirs(output_dir) + if rank == 0: + log_filename = f"{output_dir}/log.txt" + elif all_ranks: + log_filename = f"{output_dir}/log_{rank}.txt" + + logger = logging.getLogger(name) + logger.setLevel(log_level_primary) + + # create formatter + FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s" + formatter = logging.Formatter(FORMAT) + + # clean up any pre-existing handlers + for h in logger.handlers: + logger.removeHandler(h) + logger.root.handlers = [] + logging.root.handlers = [] + + # setup the console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + if rank == 0: + console_handler.setLevel(log_level_primary) + else: + console_handler.setLevel(log_level_secondary) + logger.addHandler(console_handler) + + # we log to file as well if user wants + if log_filename is not None: + file_handler = logging.StreamHandler(_cached_log_stream(log_filename)) + file_handler.setLevel(log_level_primary) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + logging.root = logger diff --git a/vggt/training/train_utils/normalization.py b/vggt/training/train_utils/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..849404706787c7bdb8b06ec49e3fa36e64420eaf --- /dev/null +++ b/vggt/training/train_utils/normalization.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import logging +from typing import Optional, Tuple +from vggt.utils.geometry import closed_form_inverse_se3 +from train_utils.general import check_and_fix_inf_nan + + +def check_valid_tensor(input_tensor: Optional[torch.Tensor], name: str = "tensor") -> None: + """ + Check if a tensor contains NaN or Inf values and log a warning if found. + + Args: + input_tensor: The tensor to check + name: Name of the tensor for logging purposes + """ + if input_tensor is not None: + if torch.isnan(input_tensor).any() or torch.isinf(input_tensor).any(): + logging.warning(f"NaN or Inf found in tensor: {name}") + + +def normalize_camera_extrinsics_and_points_batch( + extrinsics: torch.Tensor, + cam_points: Optional[torch.Tensor] = None, + world_points: Optional[torch.Tensor] = None, + depths: Optional[torch.Tensor] = None, + scale_by_points: bool = True, + point_masks: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Normalize camera extrinsics and corresponding 3D points. + + This function transforms the coordinate system to be centered at the first camera + and optionally scales the scene to have unit average distance. 
+ + Args: + extrinsics: Camera extrinsic matrices of shape (B, S, 3, 4) + cam_points: 3D points in camera coordinates of shape (B, S, H, W, 3) or (*,3) + world_points: 3D points in world coordinates of shape (B, S, H, W, 3) or (*,3) + depths: Depth maps of shape (B, S, H, W) + scale_by_points: Whether to normalize the scale based on point distances + point_masks: Boolean masks for valid points of shape (B, S, H, W) + + Returns: + Tuple containing: + - Normalized camera extrinsics of shape (B, S, 3, 4) + - Normalized camera points (same shape as input cam_points) + - Normalized world points (same shape as input world_points) + - Normalized depths (same shape as input depths) + """ + # Validate inputs + check_valid_tensor(extrinsics, "extrinsics") + check_valid_tensor(cam_points, "cam_points") + check_valid_tensor(world_points, "world_points") + check_valid_tensor(depths, "depths") + + + B, S, _, _ = extrinsics.shape + device = extrinsics.device + assert device == torch.device("cpu") + + + # Convert extrinsics to homogeneous form: (B, N,4,4) + extrinsics_homog = torch.cat( + [ + extrinsics, + torch.zeros((B, S, 1, 4), device=device), + ], + dim=-2, + ) + extrinsics_homog[:, :, -1, -1] = 1.0 + + # first_cam_extrinsic_inv, the inverse of the first camera's extrinsic matrix + # which can be also viewed as the cam_to_world extrinsic matrix + first_cam_extrinsic_inv = closed_form_inverse_se3(extrinsics_homog[:, 0]) + # new_extrinsics = torch.matmul(extrinsics_homog, first_cam_extrinsic_inv) + new_extrinsics = torch.matmul(extrinsics_homog, first_cam_extrinsic_inv.unsqueeze(1)) # (B,N,4,4) + + + if world_points is not None: + # since we are transforming the world points to the first camera's coordinate system + # we directly use the cam_from_world extrinsic matrix of the first camera + # instead of using the inverse of the first camera's extrinsic matrix + R = extrinsics[:, 0, :3, :3] + t = extrinsics[:, 0, :3, 3] + new_world_points = (world_points @ R.transpose(-1, -2).unsqueeze(1).unsqueeze(2)) + t.unsqueeze(1).unsqueeze(2).unsqueeze(3) + else: + new_world_points = None + + + if scale_by_points: + new_cam_points = cam_points.clone() + new_depths = depths.clone() + + dist = new_world_points.norm(dim=-1) + dist_sum = (dist * point_masks).sum(dim=[1,2,3]) + valid_count = point_masks.sum(dim=[1,2,3]) + avg_scale = (dist_sum / (valid_count + 1e-3)).clamp(min=1e-6, max=1e6) + + + new_world_points = new_world_points / avg_scale.view(-1, 1, 1, 1, 1) + new_extrinsics[:, :, :3, 3] = new_extrinsics[:, :, :3, 3] / avg_scale.view(-1, 1, 1) + if depths is not None: + new_depths = new_depths / avg_scale.view(-1, 1, 1, 1) + if cam_points is not None: + new_cam_points = new_cam_points / avg_scale.view(-1, 1, 1, 1, 1) + else: + return new_extrinsics[:, :, :3], cam_points, new_world_points, depths + + new_extrinsics = new_extrinsics[:, :, :3] # 4x4 -> 3x4 + new_extrinsics = check_and_fix_inf_nan(new_extrinsics, "new_extrinsics", hard_max=None) + new_cam_points = check_and_fix_inf_nan(new_cam_points, "new_cam_points", hard_max=None) + new_world_points = check_and_fix_inf_nan(new_world_points, "new_world_points", hard_max=None) + new_depths = check_and_fix_inf_nan(new_depths, "new_depths", hard_max=None) + + + return new_extrinsics, new_cam_points, new_world_points, new_depths + + + + + diff --git a/vggt/training/train_utils/optimizer.py b/vggt/training/train_utils/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7c033a13c72c33833c2a5d3ba4846fc521b47bbd --- /dev/null +++ 
b/vggt/training/train_utils/optimizer.py @@ -0,0 +1,273 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import itertools +from typing import Any, Dict, List, Mapping, Iterable, Set, Tuple, Union + +import hydra +import torch +import torch.nn as nn +from torch import Tensor + +# ----------------------------------------------------------------------------- +# Optimizer wrapper +# ----------------------------------------------------------------------------- + +class OptimizerWrapper: + """Wraps a torch.optim.Optimizer and its schedulers (if any).""" + + def __init__(self, optimizer: torch.optim.Optimizer, schedulers=None) -> None: + self.optimizer = optimizer + self.schedulers = schedulers + self._validate_optimizer_schedulers() + self.step_schedulers(0.0) + + # --------------------------------------------------------------------- + # Public API mirroring torch.optim.Optimizer + # --------------------------------------------------------------------- + + def step(self, where: float = 1.0, closure=None): + """Update the optimizer & its schedulers.""" + self.step_schedulers(where) + return self.optimizer.step(closure) + + def zero_grad(self, *args, **kwargs): + return self.optimizer.zero_grad(*args, **kwargs) + + def _validate_optimizer_schedulers(self): + if self.schedulers is None: + return + for _, sched_map in enumerate(self.schedulers): + for option, _ in sched_map.items(): + assert option in self.optimizer.defaults, ( + f"Optimizer option {option} not found in {self.optimizer}. " + f"Valid options are {self.optimizer.defaults.keys()}" + ) + + def step_schedulers(self, where: float) -> None: + if self.schedulers is None: + return + for i, param_group in enumerate(self.optimizer.param_groups): + for option, scheduler in self.schedulers[i].items(): + param_group[option] = scheduler(where) + + +# ----------------------------------------------------------------------------- +# Validation helpers +# ----------------------------------------------------------------------------- + + +def validate_param_group_params(param_groups: List[Dict], model: nn.Module): + """Ensure param groups are non-overlapping and include all model params.""" + + for pg in param_groups: + assert len(pg["params"]) == len(set(pg["params"])) + + parameters = [set(pg["params"]) for pg in param_groups] + model_parameters = {p for _, p in model.named_parameters()} + + for p1, p2 in itertools.permutations(parameters, 2): + assert p1.isdisjoint(p2), "Parameter groups should be disjoint" + + assert set.union(*parameters) == model_parameters, ( + "Parameter groups must cover ALL model parameters " + f"(found {len(set.union(*parameters))} / {len(model_parameters)})" + ) + + +# ----------------------------------------------------------------------------- +# Glob helpers for pattern matching +# ----------------------------------------------------------------------------- + +from wcmatch import fnmatch + +GLOB_FLAGS = ( + fnmatch.CASE # case-sensitive + | fnmatch.DOTMATCH # '*' also matches '.' 
+ | fnmatch.EXTMATCH # extended patterns like *(foo|bar) + | fnmatch.SPLIT # "pat1|pat2" works out-of-the-box +) + + +def get_full_parameter_name(module_name: str, param_name: str) -> str: + return param_name if module_name == "" else f"{module_name}.{param_name}" + + +def get_module_cls_to_param_names(model: nn.Module) -> Dict[type, Set[str]]: + """Map each module class to the *immediate* param names it owns.""" + mapping: Dict[type, Set[str]] = {} + for module_name, module in model.named_modules(): + module_cls = type(module) + mapping.setdefault(module_cls, set()) + for pname, _ in module.named_parameters(recurse=False): + mapping[module_cls].add(get_full_parameter_name(module_name, pname)) + return mapping + + +def unix_param_pattern_to_parameter_names(filter_param_names: Union[List[str], None], + parameter_names: Set[str]) -> Set[str]: + if filter_param_names is None: + return set() + allowed = [] + for pat in filter_param_names: + matches = set(fnmatch.filter(parameter_names, pat, flags=GLOB_FLAGS)) + if not matches: + raise AssertionError(f"Pattern {pat} matched no parameters") + logging.info(f"Matches for param pattern [{pat}]: {matches}") + allowed.append(matches) + return set.union(*allowed) + + +def unix_module_cls_pattern_to_parameter_names(filter_module_cls_names: Union[List[str], None], + module_cls_to_param_names: Dict[type, Set[str]]) -> Set[str]: + if filter_module_cls_names is None: + return set() + allowed = [] + for cls_name in filter_module_cls_names: + module_cls = hydra.utils.get_class(cls_name) + if module_cls not in module_cls_to_param_names: + raise AssertionError(f"Module class {cls_name} not found in model") + params = module_cls_to_param_names[module_cls] + if not params: + raise AssertionError(f"Module class {cls_name} has no parameters") + logging.info(f"Matches for module [{cls_name}]: {params}") + allowed.append(params) + return set.union(*allowed) + + +def _unix_pattern_to_parameter_names(scheduler_cfg, + parameter_names: Set[str], + module_cls_to_param_names: Dict[type, Set[str]]): + if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg: + return None + return unix_param_pattern_to_parameter_names( + scheduler_cfg.get("param_names"), parameter_names + ).union( + unix_module_cls_pattern_to_parameter_names( + scheduler_cfg.get("module_cls_names"), module_cls_to_param_names + ) + ) + + +# ----------------------------------------------------------------------------- +# Scheduler helpers +# ----------------------------------------------------------------------------- + + +def set_default_parameters(scheduler_cfgs: List[dict], all_parameter_names: Set[str]): + """Ensure exactly one scheduler per option acts as the default.""" + specified = [cfg["parameter_names"] for cfg in scheduler_cfgs if cfg["parameter_names"]] + + default_params = ( + all_parameter_names if not specified else all_parameter_names - set.union(*specified) + ) + + default_count = 0 + for cfg in scheduler_cfgs: + if cfg["parameter_names"] is None: + cfg["parameter_names"] = default_params + default_count += 1 + assert default_count <= 1, "At most one default scheduler per option" + + if default_count == 0: + scheduler_cfgs.append({"parameter_names": default_params}) + + +def name_constraints_to_parameters(param_constraints: List[Set[str]], + named_parameters: Dict[str, Tensor]) -> List[Tensor]: + matching_names = set.intersection(*param_constraints) + return [v for k, v in named_parameters.items() if k in matching_names] + + +def 
map_scheduler_cfgs_to_param_groups(all_scheduler_cfgs: Iterable[List[dict]], + named_parameters: Dict[str, Tensor]): + """Produce param groups & schedulers that torch.optim can consume.""" + schedulers: List[Dict[str, Any]] = [] + param_groups: List[Dict[str, List[Tensor]]] = [] + + for cfgs in itertools.product(*all_scheduler_cfgs): + param_constraints = [cfg["parameter_names"] for cfg in cfgs] + matching = name_constraints_to_parameters(param_constraints, named_parameters) + if not matching: + continue # no intersection of params for this combo + schedulers.append({cfg["option"]: cfg["scheduler"] for cfg in cfgs if "option" in cfg}) + param_groups.append({"params": matching}) + + return schedulers, param_groups + + +# ----------------------------------------------------------------------------- +# Public factory functions +# ----------------------------------------------------------------------------- + + +def construct_optimizer(model: nn.Module, + optimizer_conf: Any, + options_conf: Union[Mapping[str, List], None] = None, + param_group_modifiers_conf: Union[List, None] = None, + validate_param_groups: bool = True) -> OptimizerWrapper: + """Build an OptimizerWrapper from hydra configs. + + *No* allowlist handling – we always optimize *all* model parameters. + """ + + named_parameters = dict(model.named_parameters()) + all_parameter_names = set(named_parameters.keys()) + module_cls_to_all_param_names = get_module_cls_to_param_names(model) + + # ────────────────────────────────────────────────────────────────── + # No scheduler case – simple & fast + # ────────────────────────────────────────────────────────────────── + if not options_conf: + optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values()) + return OptimizerWrapper(optimizer) + + # ────────────────────────────────────────────────────────────────── + # Build option-specific scheduler configs + # ────────────────────────────────────────────────────────────────── + scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf) + all_scheduler_cfgs: List[List[dict]] = [] + + for option, cfg_list in scheduler_cfgs_per_option.items(): + for cfg in cfg_list: + cfg.option = option # annotate + cfg.parameter_names = _unix_pattern_to_parameter_names( + cfg, all_parameter_names, module_cls_to_all_param_names + ) + set_default_parameters(cfg_list, all_parameter_names) + all_scheduler_cfgs.append(cfg_list) + + # User-provided modifiers (rare) + if param_group_modifiers_conf: + for modifier in param_group_modifiers_conf: + modifier = hydra.utils.instantiate(modifier) + all_scheduler_cfgs = modifier(scheduler_cfgs=all_scheduler_cfgs, model=model) + + # Map scheduler cfg combos to optimizer param groups + schedulers, param_groups = map_scheduler_cfgs_to_param_groups( + all_scheduler_cfgs, named_parameters + ) + + if validate_param_groups: + validate_param_group_params(param_groups, model) + + optimizer = hydra.utils.instantiate(optimizer_conf, param_groups) + return OptimizerWrapper(optimizer, schedulers) + + +def construct_optimizers(model: nn.Module, optim_conf) -> Union[List[OptimizerWrapper], None]: + """Convenience wrapper producing a *single* OptimizerWrapper list.""" + if optim_conf is None: + return None + + optimizer = construct_optimizer( + model, + optim_conf.optimizer, + optim_conf.options, + validate_param_groups=True, + ) + return [optimizer] diff --git a/vggt/training/train_utils/tb_writer.py b/vggt/training/train_utils/tb_writer.py new file mode 100644 index 
0000000000000000000000000000000000000000..8bd54e950e3fdfd3b9ff2d86dbf3e57cfd25c198 --- /dev/null +++ b/vggt/training/train_utils/tb_writer.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import atexit +import logging +import uuid +from typing import Any, Dict, Optional, Union + +import torch +from torch.utils.tensorboard import SummaryWriter + +from .distributed import get_machine_local_and_dist_rank + + +class TensorBoardLogger: + """A wrapper around TensorBoard SummaryWriter with distributed training support. + + This logger only writes from rank 0 in distributed settings to avoid conflicts. + Automatically handles cleanup on exit. + """ + + def __init__( + self, + path: str, + *args: Any, + filename_suffix: Optional[str] = None, + summary_writer_method: Any = SummaryWriter, + **kwargs: Any, + ) -> None: + """Initialize TensorBoard logger. + + Args: + path: Directory path where TensorBoard logs will be stored + filename_suffix: Optional suffix for log filename. If None, uses random UUID + summary_writer_method: SummaryWriter class or compatible alternative + *args, **kwargs: Additional arguments passed to SummaryWriter + """ + self._writer: Optional[SummaryWriter] = None + _, self._rank = get_machine_local_and_dist_rank() + self._path: str = path + if self._rank == 0: + logging.info( + f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}" + ) + self._writer = summary_writer_method( + log_dir=path, + *args, + filename_suffix=filename_suffix or str(uuid.uuid4()), + **kwargs, + ) + else: + logging.debug( + f"Not logging on this process because rank {self._rank} != 0" + ) + + atexit.register(self.close) + + @property + def writer(self) -> Optional[SummaryWriter]: + """Get the underlying SummaryWriter instance.""" + return self._writer + + @property + def path(self) -> str: + """Get the log directory path.""" + return self._path + + def flush(self) -> None: + """Write pending logs to disk.""" + if self._writer: + self._writer.flush() + + def close(self) -> None: + """Close writer and flush pending logs to disk. + + Logs cannot be written after close() is called. + """ + if self._writer: + self._writer.close() + self._writer = None + + def log_dict(self, payload: Dict[str, Any], step: int) -> None: + """Log multiple scalar values to TensorBoard. + + Args: + payload: Dictionary mapping tag names to scalar values + step: Step value to record + """ + if not self._writer: + return + + for key, value in payload.items(): + self.log(key, value, step) + + def log(self, name: str, data: Any, step: int) -> None: + """Log scalar data to TensorBoard. + + Args: + name: Tag name used to group scalars + data: Scalar data to log (float/int/Tensor) + step: Step value to record + """ + if not self._writer: + return + + self._writer.add_scalar(name, data, global_step=step, new_style=True) + + def log_visuals( + self, + name: str, + data: Union[torch.Tensor, Any], + step: int, + fps: int = 4 + ) -> None: + """Log image or video data to TensorBoard. 
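
        A minimal usage sketch (tag names and tensors are illustrative, not part of
        the original code; shapes follow the ndim checks in the body below):

            logger.log_visuals("val/pred_image", image_chw, step=100)           # (C, H, W) -> add_image
            logger.log_visuals("val/pred_video", video_ntchw, step=100, fps=4)  # (N, T, C, H, W) -> add_video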
+ + Args: + name: Tag name used to group visuals + data: Image tensor (3D) or video tensor (5D) + step: Step value to record + fps: Frames per second for video data + + Raises: + ValueError: If data dimensions are not supported (must be 3D or 5D) + """ + if not self._writer: + return + + if data.ndim == 3: + self._writer.add_image(name, data, global_step=step) + elif data.ndim == 5: + self._writer.add_video(name, data, global_step=step, fps=fps) + else: + raise ValueError( + f"Unsupported data dimensions: {data.ndim}. " + "Expected 3D for images or 5D for videos." + ) diff --git a/vggt/training/trainer.py b/vggt/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..21ffa53e8210e108e6ee6a20a816aee8c9516971 --- /dev/null +++ b/vggt/training/trainer.py @@ -0,0 +1,868 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + + +# --- Environment Variable Setup for Performance and Debugging --- +# Helps with memory fragmentation in PyTorch's memory allocator. +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' +# Specifies the threading layer for MKL, can prevent hangs in some environments. +os.environ["MKL_THREADING_LAYER"] = "GNU" +# Provides full Hydra stack traces on error for easier debugging. +os.environ["HYDRA_FULL_ERROR"] = "1" +# Enables asynchronous error handling for NCCL, which can prevent hangs. +os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" + + +import contextlib +import gc +import json +import logging +import math +import time +from datetime import timedelta +from typing import Any, Dict, List, Mapping, Optional, Sequence + +import torch +import torch.distributed as dist +import torch.nn as nn +import torchvision +from hydra.utils import instantiate +from iopath.common.file_io import g_pathmgr + +from train_utils.checkpoint import DDPCheckpointSaver +from train_utils.distributed import get_machine_local_and_dist_rank +from train_utils.freeze import freeze_modules +from train_utils.general import * +from train_utils.logging import setup_logging +from train_utils.normalization import normalize_camera_extrinsics_and_points_batch +from train_utils.optimizer import construct_optimizers + + +class Trainer: + """ + A generic trainer for DDP training. This should naturally support multi-node training. + + This class orchestrates the entire training and validation process, including: + - Setting up the distributed environment (DDP). + - Initializing the model, optimizers, loss functions, and data loaders. + - Handling checkpointing for resuming training. + - Executing the main training and validation loops. + - Logging metrics and visualizations to TensorBoard. + """ + + EPSILON = 1e-8 + + def __init__( + self, + *, + data: Dict[str, Any], + model: Dict[str, Any], + logging: Dict[str, Any], + checkpoint: Dict[str, Any], + max_epochs: int, + mode: str = "train", + device: str = "cuda", + seed_value: int = 123, + val_epoch_freq: int = 1, + distributed: Dict[str, bool] = None, + cuda: Dict[str, bool] = None, + limit_train_batches: Optional[int] = None, + limit_val_batches: Optional[int] = None, + optim: Optional[Dict[str, Any]] = None, + loss: Optional[Dict[str, Any]] = None, + env_variables: Optional[Dict[str, Any]] = None, + accum_steps: int = 1, + **kwargs, + ): + """ + Initializes the Trainer. + + Args: + data: Hydra config for datasets and dataloaders. 
+ model: Hydra config for the model. + logging: Hydra config for logging (TensorBoard, log frequencies). + checkpoint: Hydra config for checkpointing. + max_epochs: Total number of epochs to train. + mode: "train" for training and validation, "val" for validation only. + device: "cuda" or "cpu". + seed_value: A random seed for reproducibility. + val_epoch_freq: Frequency (in epochs) to run validation. + distributed: Hydra config for DDP settings. + cuda: Hydra config for CUDA-specific settings (e.g., cuDNN). + limit_train_batches: Limit the number of training batches per epoch (for debugging). + limit_val_batches: Limit the number of validation batches per epoch (for debugging). + optim: Hydra config for optimizers and schedulers. + loss: Hydra config for the loss function. + env_variables: Dictionary of environment variables to set. + accum_steps: Number of steps to accumulate gradients before an optimizer step. + """ + self._setup_env_variables(env_variables) + self._setup_timers() + + # Store Hydra configurations + self.data_conf = data + self.model_conf = model + self.loss_conf = loss + self.logging_conf = logging + self.checkpoint_conf = checkpoint + self.optim_conf = optim + + # Store hyperparameters + self.accum_steps = accum_steps + self.max_epochs = max_epochs + self.mode = mode + self.val_epoch_freq = val_epoch_freq + self.limit_train_batches = limit_train_batches + self.limit_val_batches = limit_val_batches + self.seed_value = seed_value + + # 'where' tracks training progress from 0.0 to 1.0 for schedulers + self.where = 0.0 + + self._setup_device(device) + self._setup_torch_dist_and_backend(cuda, distributed) + + # Setup logging directory and configure logger + safe_makedirs(self.logging_conf.log_dir) + setup_logging( + __name__, + output_dir=self.logging_conf.log_dir, + rank=self.rank, + log_level_primary=self.logging_conf.log_level_primary, + log_level_secondary=self.logging_conf.log_level_secondary, + all_ranks=self.logging_conf.all_ranks, + ) + set_seeds(seed_value, self.max_epochs, self.distributed_rank) + + assert is_dist_avail_and_initialized(), "Torch distributed needs to be initialized before calling the trainer." + + # Instantiate components (model, loss, etc.) 
+ self._setup_components() + self._setup_dataloaders() + + # Move model to the correct device + self.model.to(self.device) + self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.4f") + + # Construct optimizers (after moving model to device) + if self.mode != "val": + self.optims = construct_optimizers(self.model, self.optim_conf) + + # Load checkpoint if available or specified + if self.checkpoint_conf.resume_checkpoint_path is not None: + self._load_resuming_checkpoint(self.checkpoint_conf.resume_checkpoint_path) + else: + ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir) + if ckpt_path is not None: + self._load_resuming_checkpoint(ckpt_path) + + # Wrap the model with DDP + self._setup_ddp_distributed_training(distributed, device) + + # Barrier to ensure all processes are synchronized before starting + dist.barrier() + + def _setup_timers(self): + """Initializes timers for tracking total elapsed time.""" + self.start_time = time.time() + self.ckpt_time_elapsed = 0 + + def _setup_env_variables(self, env_variables_conf: Optional[Dict[str, Any]]) -> None: + """Sets environment variables from the configuration.""" + if env_variables_conf: + for variable_name, value in env_variables_conf.items(): + os.environ[variable_name] = value + logging.info(f"Environment:\n{json.dumps(dict(os.environ), sort_keys=True, indent=2)}") + + def _setup_torch_dist_and_backend(self, cuda_conf: Dict, distributed_conf: Dict) -> None: + """Initializes the distributed process group and configures PyTorch backends.""" + if torch.cuda.is_available(): + # Configure CUDA backend settings for performance + torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic + torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark + torch.backends.cuda.matmul.allow_tf32 = cuda_conf.allow_tf32 + torch.backends.cudnn.allow_tf32 = cuda_conf.allow_tf32 + + # Initialize the DDP process group + dist.init_process_group( + backend=distributed_conf.backend, + timeout=timedelta(minutes=distributed_conf.timeout_mins) + ) + self.rank = dist.get_rank() + + def _load_resuming_checkpoint(self, ckpt_path: str): + """Loads a checkpoint from the given path to resume training.""" + logging.info(f"Resuming training from {ckpt_path} (rank {self.rank})") + + with g_pathmgr.open(ckpt_path, "rb") as f: + checkpoint = torch.load(f, map_location="cpu") + + # Load model state + model_state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint + missing, unexpected = self.model.load_state_dict( + model_state_dict, strict=self.checkpoint_conf.strict + ) + if self.rank == 0: + logging.info(f"Model state loaded. Missing keys: {missing or 'None'}. 
Unexpected keys: {unexpected or 'None'}.") + + # Load optimizer state if available and in training mode + if "optimizer" in checkpoint: + logging.info(f"Loading optimizer state dict (rank {self.rank})") + self.optims.optimizer.load_state_dict(checkpoint["optimizer"]) + + # Load training progress + if "epoch" in checkpoint: + self.epoch = checkpoint["epoch"] + self.steps = checkpoint["steps"] if "steps" in checkpoint else {"train": 0, "val": 0} + self.ckpt_time_elapsed = checkpoint.get("time_elapsed", 0) + + # Load AMP scaler state if available + if self.optim_conf.amp.enabled and "scaler" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler"]) + + def _setup_device(self, device: str): + """Sets up the device for training (CPU or CUDA).""" + self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank() + if device == "cuda": + self.device = torch.device("cuda", self.local_rank) + torch.cuda.set_device(self.local_rank) + elif device == "cpu": + self.device = torch.device("cpu") + else: + raise ValueError(f"Unsupported device: {device}") + + def _setup_components(self): + """Initializes all core training components using Hydra configs.""" + logging.info("Setting up components: Model, Loss, Logger, etc.") + self.epoch = 0 + self.steps = {'train': 0, 'val': 0} + + # Instantiate components from configs + self.tb_writer = instantiate(self.logging_conf.tensorboard_writer, _recursive_=False) + self.model = instantiate(self.model_conf, _recursive_=False) + self.loss = instantiate(self.loss_conf, _recursive_=False) + self.gradient_clipper = instantiate(self.optim_conf.gradient_clip) + self.scaler = torch.cuda.amp.GradScaler(enabled=self.optim_conf.amp.enabled) + + # Freeze specified model parameters if any + if getattr(self.optim_conf, "frozen_module_names", None): + logging.info( + f"[Start] Freezing modules: {self.optim_conf.frozen_module_names} on rank {self.distributed_rank}" + ) + self.model = freeze_modules( + self.model, + patterns=self.optim_conf.frozen_module_names, + ) + logging.info( + f"[Done] Freezing modules: {self.optim_conf.frozen_module_names} on rank {self.distributed_rank}" + ) + + # Log model summary on rank 0 + if self.rank == 0: + model_summary_path = os.path.join(self.logging_conf.log_dir, "model.txt") + model_summary(self.model, log_file=model_summary_path) + logging.info(f"Model summary saved to {model_summary_path}") + + logging.info("Successfully initialized training components.") + + def _setup_dataloaders(self): + """Initializes train and validation datasets and dataloaders.""" + self.train_dataset = None + self.val_dataset = None + + if self.mode in ["train", "val"]: + self.val_dataset = instantiate( + self.data_conf.get('val', None), _recursive_=False + ) + if self.val_dataset is not None: + self.val_dataset.seed = self.seed_value + + if self.mode in ["train"]: + self.train_dataset = instantiate(self.data_conf.train, _recursive_=False) + self.train_dataset.seed = self.seed_value + + def _setup_ddp_distributed_training(self, distributed_conf: Dict, device: str): + """Wraps the model with DistributedDataParallel (DDP).""" + assert isinstance(self.model, torch.nn.Module) + + ddp_options = dict( + find_unused_parameters=distributed_conf.find_unused_parameters, + gradient_as_bucket_view=distributed_conf.gradient_as_bucket_view, + bucket_cap_mb=distributed_conf.bucket_cap_mb, + broadcast_buffers=distributed_conf.broadcast_buffers, + ) + + self.model = nn.parallel.DistributedDataParallel( + self.model, + device_ids=[self.local_rank] if device == "cuda" 
else [], + **ddp_options, + ) + + def save_checkpoint(self, epoch: int, checkpoint_names: Optional[List[str]] = None): + """ + Saves a training checkpoint. + + Args: + epoch: The current epoch number. + checkpoint_names: A list of names for the checkpoint file (e.g., "checkpoint_latest"). + If None, saves "checkpoint" and "checkpoint_{epoch}" on frequency. + """ + checkpoint_folder = self.checkpoint_conf.save_dir + safe_makedirs(checkpoint_folder) + if checkpoint_names is None: + checkpoint_names = ["checkpoint"] + if ( + self.checkpoint_conf.save_freq > 0 + and int(epoch) % self.checkpoint_conf.save_freq == 0 + and (int(epoch) > 0 or self.checkpoint_conf.save_freq == 1) + ): + checkpoint_names.append(f"checkpoint_{int(epoch)}") + + checkpoint_content = { + "prev_epoch": epoch, + "steps": self.steps, + "time_elapsed": self.time_elapsed_meter.val, + "optimizer": [optim.optimizer.state_dict() for optim in self.optims], + } + + if len(self.optims) == 1: + checkpoint_content["optimizer"] = checkpoint_content["optimizer"][0] + if self.optim_conf.amp.enabled: + checkpoint_content["scaler"] = self.scaler.state_dict() + + # Save the checkpoint for DDP only + saver = DDPCheckpointSaver( + checkpoint_folder, + checkpoint_names=checkpoint_names, + rank=self.distributed_rank, + epoch=epoch, + ) + + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + model = self.model.module + + saver.save_checkpoint( + model=model, + ema_models = None, + skip_saving_parameters=[], + **checkpoint_content, + ) + + + + + def _get_scalar_log_keys(self, phase: str) -> List[str]: + """Retrieves keys for scalar values to be logged for a given phase.""" + if self.logging_conf.scalar_keys_to_log: + return self.logging_conf.scalar_keys_to_log[phase].keys_to_log + return [] + + def run(self): + """Main entry point to start the training or validation process.""" + assert self.mode in ["train", "val"], f"Invalid mode: {self.mode}" + if self.mode == "train": + self.run_train() + # Optionally run a final validation after all training is done + self.run_val() + elif self.mode == "val": + self.run_val() + else: + raise ValueError(f"Invalid mode: {self.mode}") + + def run_train(self): + """Runs the main training loop over all epochs.""" + while self.epoch < self.max_epochs: + set_seeds(self.seed_value + self.epoch * 100, self.max_epochs, self.distributed_rank) + + dataloader = self.train_dataset.get_loader(epoch=int(self.epoch + self.distributed_rank)) + self.train_epoch(dataloader) + + # Save checkpoint after each training epoch + self.save_checkpoint(self.epoch) + + # Clean up memory + del dataloader + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # Run validation at the specified frequency + # Skips validation after the last training epoch, as it can be run separately. + if self.epoch % self.val_epoch_freq == 0 and self.epoch < self.max_epochs - 1: + self.run_val() + + self.epoch += 1 + + self.epoch -= 1 + + def run_val(self): + """Runs a full validation epoch if a validation dataset is available.""" + if not self.val_dataset: + logging.info("No validation dataset configured. 
Skipping validation.") + return + + dataloader = self.val_dataset.get_loader(epoch=int(self.epoch + self.distributed_rank)) + self.val_epoch(dataloader) + + del dataloader + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + + @torch.no_grad() + def val_epoch(self, val_loader): + batch_time = AverageMeter("Batch Time", self.device, ":.4f") + data_time = AverageMeter("Data Time", self.device, ":.4f") + mem = AverageMeter("Mem (GB)", self.device, ":.4f") + data_times = [] + phase = 'val' + + loss_names = self._get_scalar_log_keys(phase) + loss_names = [f"Loss/{phase}_{name}" for name in loss_names] + loss_meters = { + name: AverageMeter(name, self.device, ":.4f") for name in loss_names + } + + progress = ProgressMeter( + num_batches=len(val_loader), + meters=[ + batch_time, + data_time, + mem, + self.time_elapsed_meter, + *loss_meters.values(), + ], + real_meters={}, + prefix="Val Epoch: [{}]".format(self.epoch), + ) + + self.model.eval() + end = time.time() + + iters_per_epoch = len(val_loader) + limit_val_batches = ( + iters_per_epoch + if self.limit_val_batches is None + else self.limit_val_batches + ) + + for data_iter, batch in enumerate(val_loader): + if data_iter > limit_val_batches: + break + + # measure data loading time + data_time.update(time.time() - end) + data_times.append(data_time.val) + + with torch.cuda.amp.autocast(enabled=False): + batch = self._process_batch(batch) + batch = copy_data_to_device(batch, self.device, non_blocking=True) + + amp_type = self.optim_conf.amp.amp_dtype + assert amp_type in ["bfloat16", "float16"], f"Invalid Amp type: {amp_type}" + if amp_type == "bfloat16": + amp_type = torch.bfloat16 + else: + amp_type = torch.float16 + + # compute output + with torch.no_grad(): + with torch.cuda.amp.autocast( + enabled=self.optim_conf.amp.enabled, + dtype=amp_type, + ): + val_loss_dict = self._step( + batch, self.model, phase, loss_meters + ) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + self.time_elapsed_meter.update( + time.time() - self.start_time + self.ckpt_time_elapsed + ) + + if torch.cuda.is_available(): + mem.update(torch.cuda.max_memory_allocated() // 1e9) + + if data_iter % self.logging_conf.log_freq == 0: + progress.display(data_iter) + + + return True + + def train_epoch(self, train_loader): + batch_time = AverageMeter("Batch Time", self.device, ":.4f") + data_time = AverageMeter("Data Time", self.device, ":.4f") + mem = AverageMeter("Mem (GB)", self.device, ":.4f") + data_times = [] + phase = 'train' + + loss_names = self._get_scalar_log_keys(phase) + loss_names = [f"Loss/{phase}_{name}" for name in loss_names] + loss_meters = { + name: AverageMeter(name, self.device, ":.4f") for name in loss_names + } + + for config in self.gradient_clipper.configs: + param_names = ",".join(config['module_names']) + loss_meters[f"Grad/{param_names}"] = AverageMeter(f"Grad/{param_names}", self.device, ":.4f") + + + progress = ProgressMeter( + num_batches=len(train_loader), + meters=[ + batch_time, + data_time, + mem, + self.time_elapsed_meter, + *loss_meters.values(), + ], + real_meters={}, + prefix="Train Epoch: [{}]".format(self.epoch), + ) + + self.model.train() + end = time.time() + + iters_per_epoch = len(train_loader) + limit_train_batches = ( + iters_per_epoch + if self.limit_train_batches is None + else self.limit_train_batches + ) + + if self.gradient_clipper is not None: + # setup gradient clipping at the beginning of training + self.gradient_clipper.setup_clipping(self.model) + 
+ for data_iter, batch in enumerate(train_loader): + if data_iter > limit_train_batches: + break + + # measure data loading time + data_time.update(time.time() - end) + data_times.append(data_time.val) + + + with torch.cuda.amp.autocast(enabled=False): + batch = self._process_batch(batch) + + batch = copy_data_to_device(batch, self.device, non_blocking=True) + + accum_steps = self.accum_steps + + if accum_steps==1: + chunked_batches = [batch] + else: + chunked_batches = chunk_batch_for_accum_steps(batch, accum_steps) + + self._run_steps_on_batch_chunks( + chunked_batches, phase, loss_meters + ) + + # compute gradient and do SGD step + assert data_iter <= limit_train_batches # allow for off by one errors + exact_epoch = self.epoch + float(data_iter) / limit_train_batches + self.where = float(exact_epoch) / self.max_epochs + + assert self.where <= 1 + self.EPSILON + if self.where < 1.0: + for optim in self.optims: + optim.step_schedulers(self.where) + else: + logging.warning( + f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]." + ) + + # Log schedulers + if self.steps[phase] % self.logging_conf.log_freq == 0: + for i, optim in enumerate(self.optims): + for j, param_group in enumerate(optim.optimizer.param_groups): + for option in optim.schedulers[j]: + optim_prefix = ( + f"{i}_" + if len(self.optims) > 1 + else ( + "" + f"{j}_" + if len(optim.optimizer.param_groups) > 1 + else "" + ) + ) + self.tb_writer.log( + os.path.join("Optim", f"{optim_prefix}", option), + param_group[option], + self.steps[phase], + ) + self.tb_writer.log( + os.path.join("Optim", "where"), + self.where, + self.steps[phase], + ) + + # Clipping gradients and detecting diverging gradients + if self.gradient_clipper is not None: + for optim in self.optims: + self.scaler.unscale_(optim.optimizer) + + grad_norm_dict = self.gradient_clipper(model=self.model) + + for key, grad_norm in grad_norm_dict.items(): + loss_meters[f"Grad/{key}"].update(grad_norm) + + # Optimizer step + for optim in self.optims: + self.scaler.step(optim.optimizer) + self.scaler.update() + + # Measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + self.time_elapsed_meter.update( + time.time() - self.start_time + self.ckpt_time_elapsed + ) + mem.update(torch.cuda.max_memory_allocated() // 1e9) + + if data_iter % self.logging_conf.log_freq == 0: + progress.display(data_iter) + + return True + + def _run_steps_on_batch_chunks( + self, + chunked_batches: List[Any], + phase: str, + loss_meters: Dict[str, AverageMeter], + ): + """ + Run the forward / backward as many times as there are chunks in the batch, + accumulating the gradients on each backward + """ + + for optim in self.optims: + optim.zero_grad(set_to_none=True) + + accum_steps = len(chunked_batches) + + amp_type = self.optim_conf.amp.amp_dtype + assert amp_type in ["bfloat16", "float16"], f"Invalid Amp type: {amp_type}" + if amp_type == "bfloat16": + amp_type = torch.bfloat16 + else: + amp_type = torch.float16 + + for i, chunked_batch in enumerate(chunked_batches): + ddp_context = ( + self.model.no_sync() + if i < accum_steps - 1 + else contextlib.nullcontext() + ) + + with ddp_context: + with torch.cuda.amp.autocast( + enabled=self.optim_conf.amp.enabled, + dtype=amp_type, + ): + loss_dict = self._step( + chunked_batch, self.model, phase, loss_meters + ) + + + loss = loss_dict["objective"] + loss_key = f"Loss/{phase}_loss_objective" + batch_size = chunked_batch["images"].shape[0] + + if not math.isfinite(loss.item()): + error_msg = 
f"Loss is {loss.item()}, attempting to stop training" + logging.error(error_msg) + return + + loss /= accum_steps + self.scaler.scale(loss).backward() + loss_meters[loss_key].update(loss.item(), batch_size) + + + def _apply_batch_repetition(self, batch: Mapping) -> Mapping: + """ + Applies a data augmentation by concatenating the original batch with a + flipped version of itself. + """ + tensor_keys = [ + "images", "depths", "extrinsics", "intrinsics", + "cam_points", "world_points", "point_masks", + ] + string_keys = ["seq_name"] + + for key in tensor_keys: + if key in batch: + original_tensor = batch[key] + batch[key] = torch.concatenate([original_tensor, + torch.flip(original_tensor, dims=[1])], + dim=0) + + for key in string_keys: + if key in batch: + batch[key] = batch[key] * 2 + + return batch + + def _process_batch(self, batch: Mapping): + if self.data_conf.train.common_config.repeat_batch: + batch = self._apply_batch_repetition(batch) + + # Normalize camera extrinsics and points. The function returns new tensors. + normalized_extrinsics, normalized_cam_points, normalized_world_points, normalized_depths = \ + normalize_camera_extrinsics_and_points_batch( + extrinsics=batch["extrinsics"], + cam_points=batch["cam_points"], + world_points=batch["world_points"], + depths=batch["depths"], + point_masks=batch["point_masks"], + ) + + # Replace the original values in the batch with the normalized ones. + batch["extrinsics"] = normalized_extrinsics + batch["cam_points"] = normalized_cam_points + batch["world_points"] = normalized_world_points + batch["depths"] = normalized_depths + + return batch + + def _step(self, batch, model: nn.Module, phase: str, loss_meters: dict): + """ + Performs a single forward pass, computes loss, and logs results. + + Returns: + A dictionary containing the computed losses. 
+ """ + # Forward pass + y_hat = model(images=batch["images"]) + + # Loss computation + loss_dict = self.loss(y_hat, batch) + + # Combine all data for logging + log_data = {**y_hat, **loss_dict, **batch} + + self._update_and_log_scalars(log_data, phase, self.steps[phase], loss_meters) + self._log_tb_visuals(log_data, phase, self.steps[phase]) + + self.steps[phase] += 1 + return loss_dict + + def _update_and_log_scalars(self, data: Mapping, phase: str, step: int, loss_meters: dict): + """Updates average meters and logs scalar values to TensorBoard.""" + keys_to_log = self._get_scalar_log_keys(phase) + batch_size = data['extrinsics'].shape[0] + + for key in keys_to_log: + if key in data: + value = data[key].item() if torch.is_tensor(data[key]) else data[key] + loss_meters[f"Loss/{phase}_{key}"].update(value, batch_size) + if step % self.logging_conf.log_freq == 0 and self.rank == 0: + self.tb_writer.log(f"Values/{phase}/{key}", value, step) + + def _log_tb_visuals(self, batch: Mapping, phase: str, step: int) -> None: + """Logs image or video visualizations to TensorBoard.""" + if not ( + self.logging_conf.log_visuals + and (phase in self.logging_conf.log_visual_frequency) + and self.logging_conf.log_visual_frequency[phase] > 0 + and (step % self.logging_conf.log_visual_frequency[phase] == 0) + and (self.logging_conf.visuals_keys_to_log is not None) + ): + return + + if phase in self.logging_conf.visuals_keys_to_log: + keys_to_log = self.logging_conf.visuals_keys_to_log[phase][ + "keys_to_log" + ] + assert ( + len(keys_to_log) > 0 + ), "Need to include some visual keys to log" + modality = self.logging_conf.visuals_keys_to_log[phase][ + "modality" + ] + assert modality in [ + "image", + "video", + ], "Currently only support video or image logging" + + name = f"Visuals/{phase}" + + visuals_to_log = torchvision.utils.make_grid( + [ + torchvision.utils.make_grid( + batch[key][0], # Ensure batch[key][0] is tensor and has at least 3 dimensions + nrow=self.logging_conf.visuals_per_batch_to_log, + ) + for key in keys_to_log if key in batch and batch[key][0].dim() >= 3 + ], + nrow=1, + ).clamp(-1, 1) + + visuals_to_log = visuals_to_log.cpu() + if visuals_to_log.dtype == torch.bfloat16: + visuals_to_log = visuals_to_log.to(torch.float16) + visuals_to_log = visuals_to_log.numpy() + + self.tb_writer.log_visuals( + name, visuals_to_log, step, self.logging_conf.video_logging_fps + ) + + + + +def chunk_batch_for_accum_steps(batch: Mapping, accum_steps: int) -> List[Mapping]: + """Splits a batch into smaller chunks for gradient accumulation.""" + if accum_steps == 1: + return [batch] + return [get_chunk_from_data(batch, i, accum_steps) for i in range(accum_steps)] + +def is_sequence_of_primitives(data: Any) -> bool: + """Checks if data is a sequence of primitive types (str, int, float, bool).""" + return ( + isinstance(data, Sequence) + and not isinstance(data, str) + and len(data) > 0 + and isinstance(data[0], (str, int, float, bool)) + ) + +def get_chunk_from_data(data: Any, chunk_id: int, num_chunks: int) -> Any: + """ + Recursively splits tensors and sequences within a data structure into chunks. + + Args: + data: The data structure to split (e.g., a dictionary of tensors). + chunk_id: The index of the chunk to retrieve. + num_chunks: The total number of chunks to split the data into. + + Returns: + A chunk of the original data structure. 
+ """ + if isinstance(data, torch.Tensor) or is_sequence_of_primitives(data): + # either a tensor or a list of primitive objects + # assert len(data) % num_chunks == 0 + start = (len(data) // num_chunks) * chunk_id + end = (len(data) // num_chunks) * (chunk_id + 1) + return data[start:end] + elif isinstance(data, Mapping): + return { + key: get_chunk_from_data(value, chunk_id, num_chunks) + for key, value in data.items() + } + elif isinstance(data, str): + # NOTE: this is a hack to support string keys in the batch + return data + elif isinstance(data, Sequence): + return [get_chunk_from_data(value, chunk_id, num_chunks) for value in data] + else: + return data + diff --git a/vggt/vggt.egg-info/PKG-INFO b/vggt/vggt.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..e24f2a0b5b899bc109734ddfe59b20166330407a --- /dev/null +++ b/vggt/vggt.egg-info/PKG-INFO @@ -0,0 +1,25 @@ +Metadata-Version: 2.4 +Name: vggt +Version: 0.0.1 +Author-email: Jianyuan Wang +Requires-Python: >=3.10 +License-File: LICENSE.txt +Requires-Dist: numpy<2 +Requires-Dist: Pillow +Requires-Dist: huggingface_hub +Requires-Dist: einops +Requires-Dist: safetensors +Requires-Dist: opencv-python +Provides-Extra: demo +Requires-Dist: gradio==5.17.1; extra == "demo" +Requires-Dist: viser==0.2.23; extra == "demo" +Requires-Dist: tqdm; extra == "demo" +Requires-Dist: hydra-core; extra == "demo" +Requires-Dist: omegaconf; extra == "demo" +Requires-Dist: opencv-python; extra == "demo" +Requires-Dist: scipy; extra == "demo" +Requires-Dist: onnxruntime; extra == "demo" +Requires-Dist: requests; extra == "demo" +Requires-Dist: trimesh; extra == "demo" +Requires-Dist: matplotlib; extra == "demo" +Dynamic: license-file diff --git a/vggt/vggt.egg-info/SOURCES.txt b/vggt/vggt.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..3bb2ca1cf5583d4d999dd9a55e26f634512deb38 --- /dev/null +++ b/vggt/vggt.egg-info/SOURCES.txt @@ -0,0 +1,48 @@ +LICENSE.txt +pyproject.toml +vggt.egg-info/PKG-INFO +vggt.egg-info/SOURCES.txt +vggt.egg-info/dependency_links.txt +vggt.egg-info/requires.txt +vggt.egg-info/top_level.txt +vggt/dependency/__init__.py +vggt/dependency/distortion.py +vggt/dependency/np_to_pycolmap.py +vggt/dependency/projection.py +vggt/dependency/track_predict.py +vggt/dependency/vggsfm_tracker.py +vggt/dependency/vggsfm_utils.py +vggt/dependency/track_modules/__init__.py +vggt/dependency/track_modules/base_track_predictor.py +vggt/dependency/track_modules/blocks.py +vggt/dependency/track_modules/modules.py +vggt/dependency/track_modules/track_refine.py +vggt/dependency/track_modules/utils.py +vggt/heads/camera_head.py +vggt/heads/dpt_head.py +vggt/heads/head_act.py +vggt/heads/track_head.py +vggt/heads/utils.py +vggt/heads/track_modules/__init__.py +vggt/heads/track_modules/base_track_predictor.py +vggt/heads/track_modules/blocks.py +vggt/heads/track_modules/modules.py +vggt/heads/track_modules/utils.py +vggt/layers/__init__.py +vggt/layers/attention.py +vggt/layers/block.py +vggt/layers/drop_path.py +vggt/layers/layer_scale.py +vggt/layers/mlp.py +vggt/layers/patch_embed.py +vggt/layers/rope.py +vggt/layers/swiglu_ffn.py +vggt/layers/vision_transformer.py +vggt/models/aggregator.py +vggt/models/vggt.py +vggt/utils/geometry.py +vggt/utils/helper.py +vggt/utils/load_fn.py +vggt/utils/pose_enc.py +vggt/utils/rotation.py +vggt/utils/visual_track.py \ No newline at end of file diff --git a/vggt/vggt.egg-info/dependency_links.txt b/vggt/vggt.egg-info/dependency_links.txt new 
file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/vggt/vggt.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/vggt/vggt.egg-info/requires.txt b/vggt/vggt.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..79328113204192a97f41f5a261ced4272a547450 --- /dev/null +++ b/vggt/vggt.egg-info/requires.txt @@ -0,0 +1,19 @@ +numpy<2 +Pillow +huggingface_hub +einops +safetensors +opencv-python + +[demo] +gradio==5.17.1 +viser==0.2.23 +tqdm +hydra-core +omegaconf +opencv-python +scipy +onnxruntime +requests +trimesh +matplotlib diff --git a/vggt/vggt.egg-info/top_level.txt b/vggt/vggt.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..025c42e0f226e8e101526f2d4662a1e78a600844 --- /dev/null +++ b/vggt/vggt.egg-info/top_level.txt @@ -0,0 +1 @@ +vggt diff --git a/vggt/vggt/__init__.py b/vggt/vggt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggt/vggt/__pycache__/__init__.cpython-39.pyc b/vggt/vggt/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60712757acad724925a662f44092514ef74779ac Binary files /dev/null and b/vggt/vggt/__pycache__/__init__.cpython-39.pyc differ diff --git a/vggt/vggt/dependency/__init__.py b/vggt/vggt/dependency/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4eab06de2c911398e0339782e327bffd4cc9f91c --- /dev/null +++ b/vggt/vggt/dependency/__init__.py @@ -0,0 +1,3 @@ +from .track_modules.track_refine import refine_track +from .track_modules.blocks import BasicEncoder, ShallowEncoder +from .track_modules.base_track_predictor import BaseTrackerPredictor diff --git a/vggt/vggt/dependency/__pycache__/__init__.cpython-39.pyc b/vggt/vggt/dependency/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10688c1a27310adf79969f38975df660e108a77d Binary files /dev/null and b/vggt/vggt/dependency/__pycache__/__init__.cpython-39.pyc differ diff --git a/vggt/vggt/dependency/__pycache__/distortion.cpython-39.pyc b/vggt/vggt/dependency/__pycache__/distortion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d885b0667c0e4d84e424b70d105a74667ccb258f Binary files /dev/null and b/vggt/vggt/dependency/__pycache__/distortion.cpython-39.pyc differ diff --git a/vggt/vggt/dependency/distortion.py b/vggt/vggt/dependency/distortion.py new file mode 100644 index 0000000000000000000000000000000000000000..b3510230265dbd088844076e9d5763a35f7d712b --- /dev/null +++ b/vggt/vggt/dependency/distortion.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import numpy as np +from typing import Union + +ArrayLike = Union[np.ndarray, torch.Tensor] + + +def _is_numpy(x: ArrayLike) -> bool: + return isinstance(x, np.ndarray) + + +def _is_torch(x: ArrayLike) -> bool: + return isinstance(x, torch.Tensor) + + +def _ensure_torch(x: ArrayLike) -> torch.Tensor: + """Convert input to torch tensor if it's not already one.""" + if _is_numpy(x): + return torch.from_numpy(x) + elif _is_torch(x): + return x + else: + return torch.tensor(x) + + +def single_undistortion(params, tracks_normalized): + """ + Apply undistortion to the normalized tracks using the given distortion parameters once. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + u_undist, v_undist = apply_distortion(params, u, v) + return torch.stack([u_undist, v_undist], dim=-1) + + +def iterative_undistortion(params, tracks_normalized, max_iterations=100, max_step_norm=1e-10, rel_step_size=1e-6): + """ + Iteratively undistort the normalized tracks using the given distortion parameters. + + Args: + params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN. + tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2]. + max_iterations (int): Maximum number of iterations for the undistortion process. + max_step_norm (float): Maximum step norm for convergence. + rel_step_size (float): Relative step size for numerical differentiation. + + Returns: + torch.Tensor: Undistorted normalized tracks tensor. + """ + params = _ensure_torch(params) + tracks_normalized = _ensure_torch(tracks_normalized) + + B, N, _ = tracks_normalized.shape + u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone() + original_u, original_v = u.clone(), v.clone() + + eps = torch.finfo(u.dtype).eps + for idx in range(max_iterations): + u_undist, v_undist = apply_distortion(params, u, v) + dx = original_u - u_undist + dy = original_v - v_undist + + step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps) + step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps) + + J_00 = (apply_distortion(params, u + step_u, v)[0] - apply_distortion(params, u - step_u, v)[0]) / (2 * step_u) + J_01 = (apply_distortion(params, u, v + step_v)[0] - apply_distortion(params, u, v - step_v)[0]) / (2 * step_v) + J_10 = (apply_distortion(params, u + step_u, v)[1] - apply_distortion(params, u - step_u, v)[1]) / (2 * step_u) + J_11 = (apply_distortion(params, u, v + step_v)[1] - apply_distortion(params, u, v - step_v)[1]) / (2 * step_v) + + J = torch.stack([torch.stack([J_00 + 1, J_01], dim=-1), torch.stack([J_10, J_11 + 1], dim=-1)], dim=-2) + + delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1)) + + u += delta[..., 0] + v += delta[..., 1] + + if torch.max((delta**2).sum(dim=-1)) < max_step_norm: + break + + return torch.stack([u, v], dim=-1) + + +def apply_distortion(extra_params, u, v): + """ + Applies radial or OpenCV distortion to the given 2D points. + + Args: + extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4. 
+ u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks. + v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks. + + Returns: + points2D (torch.Tensor): Distorted 2D points of shape BxNx2. + """ + extra_params = _ensure_torch(extra_params) + u = _ensure_torch(u) + v = _ensure_torch(v) + + num_params = extra_params.shape[1] + + if num_params == 1: + # Simple radial distortion + k = extra_params[:, 0] + u2 = u * u + v2 = v * v + r2 = u2 + v2 + radial = k[:, None] * r2 + du = u * radial + dv = v * radial + + elif num_params == 2: + # RadialCameraModel distortion + k1, k2 = extra_params[:, 0], extra_params[:, 1] + u2 = u * u + v2 = v * v + r2 = u2 + v2 + radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 + du = u * radial + dv = v * radial + + elif num_params == 4: + # OpenCVCameraModel distortion + k1, k2, p1, p2 = (extra_params[:, 0], extra_params[:, 1], extra_params[:, 2], extra_params[:, 3]) + u2 = u * u + v2 = v * v + uv = u * v + r2 = u2 + v2 + radial = k1[:, None] * r2 + k2[:, None] * r2 * r2 + du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2) + dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2) + else: + raise ValueError("Unsupported number of distortion parameters") + + u = u.clone() + du + v = v.clone() + dv + + return u, v + + +if __name__ == "__main__": + import random + import pycolmap + + max_diff = 0 + for i in range(1000): + # Define distortion parameters (assuming 1 parameter for simplicity) + B = random.randint(1, 500) + track_num = random.randint(100, 1000) + params = torch.rand((B, 1), dtype=torch.float32) # Batch size 1, 4 parameters + tracks_normalized = torch.rand((B, track_num, 2), dtype=torch.float32) # Batch size 1, 5 points + + # Undistort the tracks + undistorted_tracks = iterative_undistortion(params, tracks_normalized) + + for b in range(B): + pycolmap_intri = np.array([1, 0, 0, params[b].item()]) + pycam = pycolmap.Camera(model="SIMPLE_RADIAL", width=1, height=1, params=pycolmap_intri, camera_id=0) + + undistorted_tracks_pycolmap = pycam.cam_from_img(tracks_normalized[b].numpy()) + diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median() + max_diff = max(max_diff, diff) + print(f"diff: {diff}, max_diff: {max_diff}") + + import pdb + + pdb.set_trace() diff --git a/vggt/vggt/dependency/np_to_pycolmap.py b/vggt/vggt/dependency/np_to_pycolmap.py new file mode 100644 index 0000000000000000000000000000000000000000..61ea578692d2b5a5cd5b6fd15836373a94351489 --- /dev/null +++ b/vggt/vggt/dependency/np_to_pycolmap.py @@ -0,0 +1,320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import pycolmap +from .projection import project_3D_points_np + + +def batch_np_matrix_to_pycolmap( + points3d, + extrinsics, + intrinsics, + tracks, + image_size, + masks=None, + max_reproj_error=None, + max_points3D_val=3000, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", + extra_params=None, + min_inlier_per_frame=64, + points_rgb=None, +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Check https://github.com/colmap/pycolmap for more details about its format + + NOTE that colmap expects images/cameras/points3D to be 1-indexed + so there is a +1 offset between colmap index and batch index + + + NOTE: different from VGGSfM, this function: + 1. Use np instead of torch + 2. 
Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP) + """ + # points3d: Px3 + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # tracks: NxPx2 + # masks: NxP + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N, P, _ = tracks.shape + assert len(extrinsics) == N + assert len(intrinsics) == N + assert len(points3d) == P + assert image_size.shape[0] == 2 + + reproj_mask = None + + if max_reproj_error is not None: + projected_points_2d, projected_points_cam = project_3D_points_np(points3d, extrinsics, intrinsics) + projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1) + projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6 + reproj_mask = projected_diff < max_reproj_error + + if masks is not None and reproj_mask is not None: + masks = np.logical_and(masks, reproj_mask) + elif masks is not None: + masks = masks + else: + masks = reproj_mask + + assert masks is not None + + if masks.sum(1).min() < min_inlier_per_frame: + print(f"Not enough inliers per frame, skip BA.") + return None, None + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + inlier_num = masks.sum(0) + valid_mask = inlier_num >= 2 # a track is invalid if without two inliers + valid_idx = np.nonzero(valid_mask)[0] + + # Only add 3D points that have sufficient 2D points + for vidx in valid_idx: + # Use RGB colors if provided, otherwise use zeros + rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3) + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb) + + num_points3D = len(valid_idx) + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params) + + camera = pycolmap.Camera( + model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world + ) + + points2D_list = [] + + point2D_idx = 0 + + # NOTE point3D_id start by 1 + for point3D_id in range(1, num_points3D + 1): + original_track_idx = valid_idx[point3D_id - 1] + + if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all(): + if masks[fidx][original_track_idx]: + # It seems we don't need +0.5 for BA + point2D_xy = tracks[fidx][original_track_idx] + # Please note when adding the Point2D object + # It not only requires the 2D xy location, but also the id to 3D point + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: + print(f"frame {fidx + 1} is out of BA") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction, valid_mask + + +def pycolmap_to_batch_np_matrix(reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"): + """ + Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays. 
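
    A minimal usage sketch (the reconstruction object is assumed to come from a prior
    bundle adjustment or from batch_np_matrix_to_pycolmap above):

        points3D, extrinsics, intrinsics, extra_params = pycolmap_to_batch_np_matrix(
            reconstruction, camera_type="SIMPLE_PINHOLE"
        )
        # points3D: (max_point3D_id, 3), extrinsics: (N, 3, 4), intrinsics: (N, 3, 3)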
+ + Args: + reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP. + device (str): Ignored in NumPy version (kept for API compatibility). + camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE"). + + Returns: + tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params. + """ + + num_images = len(reconstruction.images) + max_points3D_id = max(reconstruction.point3D_ids()) + points3D = np.zeros((max_points3D_id, 3)) + + for point3D_id in reconstruction.points3D: + points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz + + extrinsics = [] + intrinsics = [] + + extra_params = [] if camera_type == "SIMPLE_RADIAL" else None + + for i in range(num_images): + # Extract and append extrinsics + pyimg = reconstruction.images[i + 1] + pycam = reconstruction.cameras[pyimg.camera_id] + matrix = pyimg.cam_from_world.matrix() + extrinsics.append(matrix) + + # Extract and append intrinsics + calibration_matrix = pycam.calibration_matrix() + intrinsics.append(calibration_matrix) + + if camera_type == "SIMPLE_RADIAL": + extra_params.append(pycam.params[-1]) + + # Convert lists to NumPy arrays instead of torch tensors + extrinsics = np.stack(extrinsics) + intrinsics = np.stack(intrinsics) + + if camera_type == "SIMPLE_RADIAL": + extra_params = np.stack(extra_params) + extra_params = extra_params[:, None] + + return points3D, extrinsics, intrinsics, extra_params + + +######################################################## + + +def batch_np_matrix_to_pycolmap_wo_track( + points3d, + points_xyf, + points_rgb, + extrinsics, + intrinsics, + image_size, + shared_camera=False, + camera_type="SIMPLE_PINHOLE", +): + """ + Convert Batched NumPy Arrays to PyCOLMAP + + Different from batch_np_matrix_to_pycolmap, this function does not use tracks. + + It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods. + + Do NOT use this for BA. 
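
    A minimal usage sketch (argument shapes follow the comments below; exporting the
    result with Reconstruction.write is illustrative):

        rec = batch_np_matrix_to_pycolmap_wo_track(
            points3d,     # (P, 3) world-space points
            points_xyf,   # (P, 3) x, y and frame index per point
            points_rgb,   # (P, 3) colors
            extrinsics,   # (N, 3, 4)
            intrinsics,   # (N, 3, 3)
            image_size,   # (2,) shared width/height
            camera_type="SIMPLE_PINHOLE",
        )
        rec.write("sparse/0")  # COLMAP-format export for downstream NVS pipelines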
+ """ + # points3d: Px3 + # points_xyf: Px3, with x, y coordinates and frame indices + # points_rgb: Px3, rgb colors + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N = len(extrinsics) + P = len(points3d) + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + for vidx in range(P): + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx]) + + camera = None + # frame idx + for fidx in range(N): + # set camera + if camera is None or (not shared_camera): + pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type) + + camera = pycolmap.Camera( + model=camera_type, width=image_size[0], height=image_size[1], params=pycolmap_intri, camera_id=fidx + 1 + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + + image = pycolmap.Image( + id=fidx + 1, name=f"image_{fidx + 1}", camera_id=camera.camera_id, cam_from_world=cam_from_world + ) + + points2D_list = [] + + point2D_idx = 0 + + points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx + points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0] + + for point3D_batch_idx in points_belong_to_fidx: + point3D_id = point3D_batch_idx + 1 + point2D_xyf = points_xyf[point3D_batch_idx] + point2D_xy = point2D_xyf[:2] + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + track.add_element(fidx + 1, point2D_idx) + point2D_idx += 1 + + assert point2D_idx == len(points2D_list) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + image.registered = True + except: + print(f"frame {fidx + 1} does not have any points") + image.registered = False + + # add image + reconstruction.add_image(image) + + return reconstruction + + +def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None): + """ + Helper function to get camera parameters based on camera type. + + Args: + fidx: Frame index + intrinsics: Camera intrinsic parameters + camera_type: Type of camera model + extra_params: Additional parameters for certain camera types + + Returns: + pycolmap_intri: NumPy array of camera parameters + """ + if camera_type == "PINHOLE": + pycolmap_intri = np.array( + [intrinsics[fidx][0, 0], intrinsics[fidx][1, 1], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]] + ) + elif camera_type == "SIMPLE_PINHOLE": + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]) + elif camera_type == "SIMPLE_RADIAL": + raise NotImplementedError("SIMPLE_RADIAL is not supported yet") + focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2 + pycolmap_intri = np.array([focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2], extra_params[fidx][0]]) + else: + raise ValueError(f"Camera type {camera_type} is not supported yet") + + return pycolmap_intri diff --git a/vggt/vggt/dependency/projection.py b/vggt/vggt/dependency/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..a98082dc2f5b3c057b398a03ab13dba470f4a111 --- /dev/null +++ b/vggt/vggt/dependency/projection.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +from .distortion import apply_distortion + + +def img_from_cam_np( + intrinsics: np.ndarray, points_cam: np.ndarray, extra_params: np.ndarray | None = None, default: float = 0.0 +) -> np.ndarray: + """ + Apply intrinsics (and optional radial distortion) to camera-space points. + + Args + ---- + intrinsics : (B,3,3) camera matrix K. + points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ. + extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None. + default : value used for np.nan replacement. + + Returns + ------- + points2D : (B,N,2) pixel coordinates. + """ + # 1. perspective divide ─────────────────────────────────────── + z = points_cam[:, 2:3, :] # (B,1,N) + points_cam_norm = points_cam / z # (B,3,N) + uv = points_cam_norm[:, :2, :] # (B,2,N) + + # 2. optional distortion ────────────────────────────────────── + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = np.stack([uu, vv], axis=1) # (B,2,N) + + # 3. homogeneous coords then K multiplication ───────────────── + ones = np.ones_like(uv[:, :1, :]) # (B,1,N) + points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N) + + # batched mat-mul: K · [u v 1]ᵀ + points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N) + points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N) + + return points2D.transpose(0, 2, 1) # (B,N,2) + + +def project_3D_points_np( + points3D: np.ndarray, + extrinsics: np.ndarray, + intrinsics: np.ndarray | None = None, + extra_params: np.ndarray | None = None, + *, + default: float = 0.0, + only_points_cam: bool = False, +): + """ + NumPy clone of ``project_3D_points``. + + Parameters + ---------- + points3D : (N,3) world-space points. + extrinsics : (B,3,4) [R|t] matrix for each of B cameras. + intrinsics : (B,3,3) K matrix (optional if you only need cam-space). + extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None. + default : value used to replace NaNs. + only_points_cam : if True, skip the projection and return points_cam with points2D as None. + + Returns + ------- + (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True, + and points_cam is (B,3,N) camera-space coordinates. + """ + # ----- 0. prep sizes ----------------------------------------------------- + N = points3D.shape[0] # #points + B = extrinsics.shape[0] # #cameras + + # ----- 1. world → homogeneous ------------------------------------------- + w_h = np.ones((N, 1), dtype=points3D.dtype) + points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4) + + # broadcast to every camera (no actual copying with np.broadcast_to) ------ + points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4) + + # ----- 2. apply extrinsics (camera frame) ------------------------------ + # X_cam = E · X_hom + # einsum: E_(b i j) · X_(b n j) → (b n i) + points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3) + points_cam = points_cam.transpose(0, 2, 1) # (B,3,N) + + if only_points_cam: + return None, points_cam + + # ----- 3. 
intrinsics + distortion --------------------------------------- + if intrinsics is None: + raise ValueError("`intrinsics` must be provided unless only_points_cam=True") + + points2D = img_from_cam_np(intrinsics, points_cam, extra_params=extra_params, default=default) + + return points2D, points_cam + + +def project_3D_points(points3D, extrinsics, intrinsics=None, extra_params=None, default=0, only_points_cam=False): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + points3D (torch.Tensor): 3D points of shape Px3. + extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. + intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. + extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion. + default (float): Default value to replace NaNs. + only_points_cam (bool): If True, skip the projection and return points2D as None. + + Returns: + tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True, + and points_cam is of shape Bx3xN. + """ + with torch.cuda.amp.autocast(dtype=torch.double): + N = points3D.shape[0] # Number of points + B = extrinsics.shape[0] # Batch size, i.e., number of cameras + points3D_homogeneous = torch.cat([points3D, torch.ones_like(points3D[..., 0:1])], dim=1) # Nx4 + # Reshape for batch processing + points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(B, -1, -1) # BxNx4 + + # Step 1: Apply extrinsic parameters + # Transform 3D points to camera coordinate system for all cameras + points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2)) + + if only_points_cam: + return None, points_cam + + # Step 2: Apply intrinsic parameters and (optional) distortion + points2D = img_from_cam(intrinsics, points_cam, extra_params, default) + + return points2D, points_cam + + +def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0): + """ + Applies intrinsic parameters and optional distortion to the given 3D points. + + Args: + intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. + points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. + extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + default (float, optional): Default value to replace NaNs in the output. + + Returns: + points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. 
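
    Example (illustrative shapes; identity intrinsics make the result equal to x/z, y/z):

        K = torch.eye(3).repeat(2, 1, 1)        # B = 2
        pts_cam = torch.rand(2, 3, 100) + 1.0   # keep z > 0
        pts_2d = img_from_cam(K, pts_cam)       # -> shape (2, 100, 2)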
+ """ + + # Normalize by the third coordinate (homogeneous division) + points_cam = points_cam / points_cam[:, 2:3, :] + # Extract uv + uv = points_cam[:, :2, :] + + # Apply distortion if extra_params are provided + if extra_params is not None: + uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1]) + uv = torch.stack([uu, vv], dim=1) + + # Prepare points_cam for batch matrix multiplication + points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN + # Apply intrinsic parameters using batch matrix multiplication + points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN + + # Extract x and y coordinates + points2D = points2D_homo[:, :2, :] # Bx2xN + + # Replace NaNs with default value + points2D = torch.nan_to_num(points2D, nan=default) + + return points2D.transpose(1, 2) # BxNx2 + + +if __name__ == "__main__": + # Set up example input + B, N = 24, 10240 + + for _ in range(100): + points3D = np.random.rand(N, 3).astype(np.float64) + extrinsics = np.random.rand(B, 3, 4).astype(np.float64) + intrinsics = np.random.rand(B, 3, 3).astype(np.float64) + + # Convert to torch tensors + points3D_torch = torch.tensor(points3D) + extrinsics_torch = torch.tensor(extrinsics) + intrinsics_torch = torch.tensor(intrinsics) + + # Run NumPy implementation + points2D_np, points_cam_np = project_3D_points_np(points3D, extrinsics, intrinsics) + + # Run torch implementation + points2D_torch, points_cam_torch = project_3D_points(points3D_torch, extrinsics_torch, intrinsics_torch) + + # Convert torch output to numpy + points2D_torch_np = points2D_torch.detach().numpy() + points_cam_torch_np = points_cam_torch.detach().numpy() + + # Compute difference + diff = np.abs(points2D_np - points2D_torch_np) + print("Difference between NumPy and PyTorch implementations:") + print(diff) + + # Check max error + max_diff = np.max(diff) + print(f"Maximum difference: {max_diff}") + + if np.allclose(points2D_np, points2D_torch_np, atol=1e-6): + print("Implementations match closely.") + else: + print("Significant differences detected.") + + if points_cam_np is not None: + points_cam_diff = np.abs(points_cam_np - points_cam_torch_np) + print("Difference between NumPy and PyTorch camera-space coordinates:") + print(points_cam_diff) + + # Check max error + max_cam_diff = np.max(points_cam_diff) + print(f"Maximum camera-space coordinate difference: {max_cam_diff}") + + if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6): + print("Camera-space coordinates match closely.") + else: + print("Significant differences detected in camera-space coordinates.") diff --git a/vggt/vggt/dependency/track_modules/__init__.py b/vggt/vggt/dependency/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggt/vggt/dependency/track_modules/__pycache__/__init__.cpython-39.pyc b/vggt/vggt/dependency/track_modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c66becc6111491c93fe1b7b285f58f407e4437a8 Binary files /dev/null and b/vggt/vggt/dependency/track_modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/vggt/vggt/dependency/track_modules/__pycache__/base_track_predictor.cpython-39.pyc b/vggt/vggt/dependency/track_modules/__pycache__/base_track_predictor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4deda8da68b292e97dca42abc3cbe66c5560d78b Binary files /dev/null and 
b/vggt/vggt/dependency/track_modules/__pycache__/base_track_predictor.cpython-39.pyc differ diff --git a/vggt/vggt/dependency/track_modules/__pycache__/blocks.cpython-39.pyc b/vggt/vggt/dependency/track_modules/__pycache__/blocks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d3c19bdd60f71fc574e7bcf2017d17e7259b42d Binary files /dev/null and b/vggt/vggt/dependency/track_modules/__pycache__/blocks.cpython-39.pyc differ diff --git a/vggt/vggt/dependency/track_modules/__pycache__/modules.cpython-39.pyc b/vggt/vggt/dependency/track_modules/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a06bff56f15560450f336877b9b07b7f3410602 Binary files /dev/null and b/vggt/vggt/dependency/track_modules/__pycache__/modules.cpython-39.pyc differ diff --git a/vggt/vggt/dependency/track_modules/__pycache__/track_refine.cpython-39.pyc b/vggt/vggt/dependency/track_modules/__pycache__/track_refine.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a095632405fbed936251249497967996231a144 Binary files /dev/null and b/vggt/vggt/dependency/track_modules/__pycache__/track_refine.cpython-39.pyc differ diff --git a/vggt/vggt/dependency/track_modules/__pycache__/utils.cpython-39.pyc b/vggt/vggt/dependency/track_modules/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83988be247eb2ed18689a4f4e8b8fd5aa529fe85 Binary files /dev/null and b/vggt/vggt/dependency/track_modules/__pycache__/utils.cpython-39.pyc differ diff --git a/vggt/vggt/dependency/track_modules/base_track_predictor.py b/vggt/vggt/dependency/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..8218c014e20baa646b612e368d8bdd1841658d65 --- /dev/null +++ b/vggt/vggt/dependency/track_modules/base_track_predictor.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
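# --- Illustrative usage sketch for the NumPy projection helpers defined above ---
# (editor's sketch, not part of the patched files; the import path is an assumption,
#  since the projection module's filename sits outside this hunk)
import numpy as np
# from vggt.dependency.projection import project_3D_points_np  # path assumed for illustration

f, cx, cy = 500.0, 320.0, 240.0
K = np.array([[[f, 0.0, cx], [0.0, f, cy], [0.0, 0.0, 1.0]]])        # (1,3,3) intrinsics
E = np.concatenate([np.eye(3), np.zeros((3, 1))], axis=1)[None]      # (1,3,4) identity pose
pts = np.array([[0.0, 0.0, 2.0], [0.5, -0.5, 4.0]])                  # (N,3) world points

pts2d, pts_cam = project_3D_points_np(pts, E, K)
# pts2d[0, 0] -> (cx, cy): a point on the optical axis lands on the principal point
# pts2d[0, 1] -> (cx + f * 0.125, cy - f * 0.125) = (382.5, 177.5)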
+ +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from .blocks import EfficientUpdateFormer, CorrBlock +from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=4, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + fine=False, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.fine = fine + + self.flows_emb_dim = latent_dim // 2 + self.transformer_dim = self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2 + + if self.fine: + # TODO this is the old dummy code, will remove this when we train next model + self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5 + else: + self.transformer_dim += (4 - self.transformer_dim % 4) % 4 + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) + + if not self.fine: + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward(self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
+ note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2 + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + # Construct the correlation block + + fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + + coord_preds = [] + + # Iterative Refinement + for itr in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + # Compute the correlation (check the implementation of CorrBlock) + + fcorr_fn.corr(track_feats) + fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim + + corrdim = fcorrs.shape[3] + + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + if transformer_input.shape[2] < self.transformer_dim: + # pad the features to match the dimension + pad_dim = self.transformer_dim - transformer_input.shape[2] + pad = torch.zeros_like(flows_emb[..., 0:pad_dim]) + transformer_input = torch.cat([transformer_input, pad], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) + sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) + + x = transformer_input + sampled_pos_emb + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta = self.updateformer(x) + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_ + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC + + # B x S x N x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 
2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + if not self.fine: + vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + vis_e = torch.sigmoid(vis_e) + else: + vis_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat + else: + return coord_preds, vis_e diff --git a/vggt/vggt/dependency/track_modules/blocks.py b/vggt/vggt/dependency/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e0017d2c25338d0ce5d3f31e3802282259c8fa36 --- /dev/null +++ b/vggt/vggt/dependency/track_modules/blocks.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + +from .utils import bilinear_sampler + +from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, stride=4): + super(BasicEncoder, self).__init__() + + self.stride = stride + self.norm_fn = "instance" + self.in_planes = output_dim // 2 + + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + + self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros") + self.relu1 = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(output_dim // 2, stride=1) + self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) + self.layer3 = self._make_layer(output_dim, stride=2) + self.layer4 = self._make_layer(output_dim, stride=2) + + self.conv2 = nn.Conv2d( + output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros" + ) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.InstanceNorm2d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + d = self.layer4(c) + + a = _bilinear_intepolate(a, self.stride, H, W) + b = _bilinear_intepolate(b, self.stride, H, W) + c = _bilinear_intepolate(c, self.stride, H, W) + d = _bilinear_intepolate(d, self.stride, H, W) + + x = self.conv2(torch.cat([a, b, c, d], dim=1)) + x = self.norm2(x) + x = self.relu2(x) + x = self.conv3(x) + return x + + +class 
ShallowEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"): + super(ShallowEncoder, self).__init__() + self.stride = stride + self.norm_fn = norm_fn + self.in_planes = output_dim + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + self.norm2 = nn.BatchNorm2d(output_dim * 2) + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros") + self.relu1 = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(output_dim, stride=2) + + self.layer2 = self._make_layer(output_dim, stride=2) + self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + self.in_planes = dim + + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + return layer1 + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + tmp = self.layer1(x) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = self.layer2(tmp) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = None + x = self.conv2(x) + x + + x = F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True) + + return x + + +def _bilinear_intepolate(x, stride, H, W): + return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True) + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. 
+ """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, input_tensor, mask=None): + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + flow = self.flow_head(tokens) + return flow + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, 
padding_mode="zeros"): + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.padding_mode = padding_mode + self.num_levels = num_levels + self.radius = radius + self.fmaps_pyramid = [] + self.multiple_track_feats = multiple_track_feats + + self.fmaps_pyramid.append(fmaps) + for i in range(self.num_levels - 1): + fmaps_ = fmaps.reshape(B * S, C, H, W) + fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) + _, _, H, W = fmaps_.shape + fmaps = fmaps_.reshape(B, S, C, H, W) + self.fmaps_pyramid.append(fmaps) + + def sample(self, coords): + r = self.radius + B, S, N, D = coords.shape + assert D == 2 + + H, W = self.H, self.W + out_pyramid = [] + for i in range(self.num_levels): + corrs = self.corrs_pyramid[i] # B, S, N, H, W + *_, H, W = corrs.shape + + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode) + corrs = corrs.view(B, S, N, -1) + + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2 + return out + + def corr(self, targets): + B, S, N, C = targets.shape + if self.multiple_track_feats: + targets_split = targets.split(C // self.num_levels, dim=-1) + B, S, N, C = targets_split[0].shape + + assert C == self.C + assert S == self.S + + fmap1 = targets + + self.corrs_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + *_, H, W = fmaps.shape + fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) + if self.multiple_track_feats: + fmap1 = targets_split[i] + corrs = torch.matmul(fmap1, fmap2s) + corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + self.corrs_pyramid.append(corrs) diff --git a/vggt/vggt/dependency/track_modules/modules.py b/vggt/vggt/dependency/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e89b26edc7717f04a897977041f26e5c4f1c52b2 --- /dev/null +++ b/vggt/vggt/dependency/track_modules/modules.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + 
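        # Pre-norm style block: norm1 is applied to the tokens before self-attention
        # (so the residual below is taken from the normalized tokens), while the MLP
        # uses a pre-norm residual via norm2. With batch_first=True,
        # nn.MultiheadAttention expects (batch, seq_len, embed_dim) inputs, which
        # matches the (B*N, T, C) time tokens and (B*T, N, C) space tokens built in
        # EfficientUpdateFormer.forward above.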
self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): + """ + Cross attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/vggt/vggt/dependency/track_modules/track_refine.py b/vggt/vggt/dependency/track_modules/track_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..54a7ace1d49686304e5fbf28c33168667c28e181 --- /dev/null +++ b/vggt/vggt/dependency/track_modules/track_refine.py @@ -0,0 +1,419 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch import nn, einsum +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +from PIL import Image +import os +from typing import Union, Tuple + + +def refine_track( + images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6, chunk=40960 +): + """ + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. 
+ """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
+ # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] + + if chunk < 0: + # Extract image patches based on top left corners + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + else: + patches = extracted_patches.reshape(B * S * N, C_in, psize, psize) + + patch_feat_list = [] + for p in torch.split(patches, chunk): + patch_feat_list += [fine_fnet(p)] + patch_feat = torch.cat(patch_feat_list, 0) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) + + return refined_tracks, score + + +def refine_track_v0( + images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6 +): + """ + COPIED FROM VGGSfM + + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. 
+ """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). 
+ # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + + # Extract image patches based on top left corners + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] + + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) + + return refined_tracks, score + + +################################## NOTE: NOT USED ################################## + + +def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out): + """ + Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps, + given the query point features and reference frame feature maps + """ + + from kornia.utils.grid import create_meshgrid + from kornia.geometry.subpix import dsnt + + # query_point_feat initial shape: B x N x C_out, + # query_point_feat indicates the feat at the coorponsing query points + # Therefore we don't have S dimension here + query_point_feat = query_point_feat.reshape(B, N, C_out) + # reshape and expand to B x (S-1) x N x C_out + query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1) + # and reshape to (B*(S-1)*N) x C_out + query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out) + + # Radius and size for computing the score + ssize = sradius * 2 + 1 + + # Reshape, you know it, so many reshaping operations + patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N) + + # Again, we unfold the patches to smaller patches + # so that we can then focus on smaller patches + # patch_feat_unfold shape: + # B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x 
ssize x ssize + # well a bit scary, but actually not + patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1) + + # Do the same stuffs above, i.e., the same as extracting patches + fine_prediction_floor = fine_pred_track.floor().int() + fine_level_floor_topleft = fine_prediction_floor - sradius + + # Clamp to ensure the smaller patch is valid + fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize) + fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2) + + # Prepare the batch indices and xy locations + + batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN + batch_indices_score = batch_indices_score.reshape(-1).to(patch_feat_unfold.device) # B*S*N + y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices + x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices + + reference_frame_feat = patch_feat_unfold.reshape( + B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize + ) + + # Note again, according to pytorch convention + # x_indices cooresponds to [..., 1] and y_indices cooresponds to [..., 0] + reference_frame_feat = reference_frame_feat[batch_indices_score, :, x_indices, y_indices] + reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize) + # pick the frames other than the first one, so we have S-1 frames here + reference_frame_feat = reference_frame_feat[:, 1:].reshape(B * (S - 1) * N, C_out, ssize * ssize) + + # Compute similarity + sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat) + softmax_temp = 1.0 / C_out**0.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) + # 2D heatmaps + heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize + + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] + grid_normalized = create_meshgrid(ssize, ssize, normalized_coordinates=True, device=heatmap.device).reshape( + 1, -1, 2 + ) + + var = torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) - coords_normalized**2 + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # clamp needed for numerical stability + + score = std.reshape(B, S - 1, N) + # set score as 1 for the query frame + score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1) + + return score + + +def extract_glimpse( + tensor: torch.Tensor, size: Tuple[int, int], offsets, mode="bilinear", padding_mode="zeros", debug=False, orib=None +): + B, C, W, H = tensor.shape + + h, w = size + xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0 + ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0 + + vy, vx = torch.meshgrid(ys, xs) + grid = torch.stack([vx, vy], dim=-1) # h, w, 2 + grid = grid[None] + + B, N, _ = offsets.shape + + offsets = offsets.reshape((B * N), 1, 1, 2) + offsets_grid = offsets + grid + + # normalised grid to [-1, 1] + offsets_grid = (offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])) / offsets_grid.new_tensor([W / 2, H / 2]) + + # BxCxHxW -> Bx1xCxHxW + tensor = tensor[:, None] + + # Bx1xCxHxW -> BxNxCxHxW + tensor = tensor.expand(-1, N, -1, -1, -1) + + # BxNxCxHxW -> (B*N)xCxHxW + tensor = tensor.reshape((B * N), C, W, H) + + sampled = torch.nn.functional.grid_sample( + tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode + ) + + # NOTE: I am not sure it should be h, w or w, h here + # but okay for sqaures + sampled = sampled.reshape(B, N, C, h, w) + + return 
sampled diff --git a/vggt/vggt/dependency/track_modules/utils.py b/vggt/vggt/dependency/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d8954e87beb85e71c5fa4b5d7eb4f2b476680e6f --- /dev/null +++ b/vggt/vggt/dependency/track_modules/utils.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from https://github.com/facebookresearch/PoseDiffusion +# and https://github.com/facebookresearch/co-tracker/tree/main + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Union +from einops import rearrange, repeat + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. 
+ + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. 
+ """ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + coords = coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device) + else: + coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) + + coords -= 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/vggt/vggt/dependency/track_predict.py b/vggt/vggt/dependency/track_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..c15c23fea612acb9383d7f03d7779b6d0f2dbf82 --- /dev/null +++ b/vggt/vggt/dependency/track_predict.py @@ -0,0 +1,326 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +from .vggsfm_utils import * + + +def predict_tracks( + images, + conf=None, + points_3d=None, + masks=None, + max_query_pts=2048, + query_frame_num=5, + keypoint_extractor="aliked+sp", + max_points_num=163840, + fine_tracking=True, + complete_non_vis=True, +): + """ + Predict tracks for the given images and masks. + + TODO: support non-square images + TODO: support masks + + + This function predicts the tracks for the given images and masks using the specified query method + and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames. + + Args: + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None. + points_3d: Tensor containing 3D points. Default is None. + masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None. + max_query_pts: Maximum number of query points. Default is 2048. + query_frame_num: Number of query frames to use. Default is 5. + keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp". + max_points_num: Maximum number of points to process at once. Default is 163840. + fine_tracking: Whether to use fine tracking. Default is True. + complete_non_vis: Whether to augment non-visible frames. Default is True. + + Returns: + pred_tracks: Numpy array containing the predicted tracks. + pred_vis_scores: Numpy array containing the visibility scores for the tracks. + pred_confs: Numpy array containing the confidence scores for the tracks. 
+ pred_points_3d: Numpy array containing the 3D points for the tracks. + pred_colors: Numpy array containing the point colors for the tracks. (0, 255) + """ + + device = images.device + dtype = images.dtype + tracker = build_vggsfm_tracker().to(device, dtype) + + # Find query frames + query_frame_indexes = generate_rank_by_dino(images, query_frame_num=query_frame_num, device=device) + + # Add the first image to the front if not already present + if 0 in query_frame_indexes: + query_frame_indexes.remove(0) + query_frame_indexes = [0, *query_frame_indexes] + + # TODO: add the functionality to handle the masks + keypoint_extractors = initialize_feature_extractors( + max_query_pts, extractor_method=keypoint_extractor, device=device + ) + + pred_tracks = [] + pred_vis_scores = [] + pred_confs = [] + pred_points_3d = [] + pred_colors = [] + + fmaps_for_tracker = tracker.process_images_to_fmaps(images) + + if fine_tracking: + print("For faster inference, consider disabling fine_tracking") + + for query_index in query_frame_indexes: + print(f"Predicting tracks for query frame {query_index}") + pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, + ) + + pred_tracks.append(pred_track) + pred_vis_scores.append(pred_vis) + pred_confs.append(pred_conf) + pred_points_3d.append(pred_point_3d) + pred_colors.append(pred_color) + + if complete_non_vis: + pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = _augment_non_visible_frames( + pred_tracks, + pred_vis_scores, + pred_confs, + pred_points_3d, + pred_colors, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + min_vis=500, + non_vis_thresh=0.1, + device=device, + ) + + pred_tracks = np.concatenate(pred_tracks, axis=1) + pred_vis_scores = np.concatenate(pred_vis_scores, axis=1) + pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None + pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None + pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None + + # from vggt.utils.visual_track import visualize_tracks_on_images + # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals") + + return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors + + +def _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num, + fine_tracking, + device, +): + """ + Process a single query frame for track prediction. 
+ + Args: + query_index: Index of the query frame + images: Tensor of shape [S, 3, H, W] containing the input images + conf: Confidence tensor + points_3d: 3D points tensor + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + device: Device to use for computation + + Returns: + pred_track: Predicted tracks + pred_vis: Visibility scores for the tracks + pred_conf: Confidence scores for the tracks + pred_point_3d: 3D points for the tracks + pred_color: Point colors for the tracks (0, 255) + """ + frame_num, _, height, width = images.shape + + query_image = images[query_index] + query_points = extract_keypoints(query_image, keypoint_extractors, round_keypoints=False) + query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] + + # Extract the color at the keypoint locations + query_points_long = query_points.squeeze(0).round().long() + pred_color = images[query_index][:, query_points_long[:, 1], query_points_long[:, 0]] + pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8) + + # Query the confidence and points_3d at the keypoint locations + if (conf is not None) and (points_3d is not None): + assert height == width + assert conf.shape[-2] == conf.shape[-1] + assert conf.shape[:3] == points_3d.shape[:3] + scale = conf.shape[-1] / width + + query_points_scaled = (query_points.squeeze(0) * scale).round().long() + query_points_scaled = query_points_scaled.cpu().numpy() + + pred_conf = conf[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] + pred_point_3d = points_3d[query_index][query_points_scaled[:, 1], query_points_scaled[:, 0]] + + # heuristic to remove low confidence points + # should I export this as an input parameter? 
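+ # The 1.2 confidence cutoff and the 512-point minimum below are fixed here; points are only filtered when more than 512 of them pass the cutoff.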
+ valid_mask = pred_conf > 1.2 + if valid_mask.sum() > 512: + query_points = query_points[:, valid_mask] # Make sure shape is compatible + pred_conf = pred_conf[valid_mask] + pred_point_3d = pred_point_3d[valid_mask] + pred_color = pred_color[valid_mask] + else: + pred_conf = None + pred_point_3d = None + + reorder_index = calculate_index_mappings(query_index, frame_num, device=device) + + images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], reorder_index, dim=0) + images_feed = images_feed[None] # add batch dimension + fmaps_feed = fmaps_feed[None] # add batch dimension + + all_points_num = images_feed.shape[1] * query_points.shape[1] + + # Don't need to be scared, this is just chunking to make GPU happy + if all_points_num > max_points_num: + num_splits = (all_points_num + max_points_num - 1) // max_points_num + query_points = torch.chunk(query_points, num_splits, dim=1) + else: + query_points = [query_points] + + pred_track, pred_vis, _ = predict_tracks_in_chunks( + tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking + ) + + pred_track, pred_vis = switch_tensor_order([pred_track, pred_vis], reorder_index, dim=1) + + pred_track = pred_track.squeeze(0).float().cpu().numpy() + pred_vis = pred_vis.squeeze(0).float().cpu().numpy() + + return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color + + +def _augment_non_visible_frames( + pred_tracks: list, # ← running list of np.ndarrays + pred_vis_scores: list, # ← running list of np.ndarrays + pred_confs: list, # ← running list of np.ndarrays for confidence scores + pred_points_3d: list, # ← running list of np.ndarrays for 3D points + pred_colors: list, # ← running list of np.ndarrays for colors + images: torch.Tensor, + conf, + points_3d, + fmaps_for_tracker, + keypoint_extractors, + tracker, + max_points_num: int, + fine_tracking: bool, + *, + min_vis: int = 500, + non_vis_thresh: float = 0.1, + device: torch.device = None, +): + """ + Augment tracking for frames with insufficient visibility. + + Args: + pred_tracks: List of numpy arrays containing predicted tracks. + pred_vis_scores: List of numpy arrays containing visibility scores. + pred_confs: List of numpy arrays containing confidence scores. + pred_points_3d: List of numpy arrays containing 3D points. + pred_colors: List of numpy arrays containing point colors. + images: Tensor of shape [S, 3, H, W] containing the input images. + conf: Tensor of shape [S, 1, H, W] containing confidence scores + points_3d: Tensor containing 3D points + fmaps_for_tracker: Feature maps for the tracker + keypoint_extractors: Initialized feature extractors + tracker: VGG-SFM tracker + max_points_num: Maximum number of points to process at once + fine_tracking: Whether to use fine tracking + min_vis: Minimum visibility threshold + non_vis_thresh: Non-visibility threshold + device: Device to use for computation + + Returns: + Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists. 
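+ New per-frame results are appended to the existing lists rather than merged in place, so the caller can still concatenate everything afterwards.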
+ """ + last_query = -1 + final_trial = False + cur_extractors = keypoint_extractors # may be replaced on the final trial + + while True: + # Visibility per frame + vis_array = np.concatenate(pred_vis_scores, axis=1) + + # Count frames with sufficient visibility using numpy + sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1) + non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist() + + if len(non_vis_frames) == 0: + break + + print("Processing non visible frames:", non_vis_frames) + + # Decide the frames & extractor for this round + if non_vis_frames[0] == last_query: + # Same frame failed twice - final "all-in" attempt + final_trial = True + cur_extractors = initialize_feature_extractors(2048, extractor_method="sp+sift+aliked", device=device) + query_frame_list = non_vis_frames # blast them all at once + else: + query_frame_list = [non_vis_frames[0]] # Process one at a time + + last_query = non_vis_frames[0] + + # Run the tracker for every selected frame + for query_index in query_frame_list: + new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query( + query_index, + images, + conf, + points_3d, + fmaps_for_tracker, + cur_extractors, + tracker, + max_points_num, + fine_tracking, + device, + ) + pred_tracks.append(new_track) + pred_vis_scores.append(new_vis) + pred_confs.append(new_conf) + pred_points_3d.append(new_point_3d) + pred_colors.append(new_color) + + if final_trial: + break # Stop after final attempt + + return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors diff --git a/vggt/vggt/dependency/vggsfm_tracker.py b/vggt/vggt/dependency/vggsfm_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..d79aeef000dcfec506dc4afb4e500d22a758122b --- /dev/null +++ b/vggt/vggt/dependency/vggsfm_tracker.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch import nn, einsum +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +from hydra.utils import instantiate +from omegaconf import OmegaConf + +from .track_modules.track_refine import refine_track +from .track_modules.blocks import BasicEncoder, ShallowEncoder +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackerPredictor(nn.Module): + def __init__(self, **extra_args): + super(TrackerPredictor, self).__init__() + """ + Initializes the tracker predictor. 
+ + Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor, + check track_modules/base_track_predictor.py + + Both coarse_fnet and fine_fnet are constructed as a 2D CNN network + check track_modules/blocks.py for BasicEncoder and ShallowEncoder + """ + # Define coarse predictor configuration + coarse_stride = 4 + self.coarse_down_ratio = 2 + + # Create networks directly instead of using instantiate + self.coarse_fnet = BasicEncoder(stride=coarse_stride) + self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride) + + # Create fine predictor with stride = 1 + self.fine_fnet = ShallowEncoder(stride=1) + self.fine_predictor = BaseTrackerPredictor( + stride=1, + depth=4, + corr_levels=3, + corr_radius=3, + latent_dim=32, + hidden_size=256, + fine=True, + use_spaceatt=False, + ) + + def forward( + self, images, query_points, fmaps=None, coarse_iters=6, inference=True, fine_tracking=True, fine_chunk=40960 + ): + """ + Args: + images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W. + query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2. + fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None. + coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6. + inference (bool, optional): Whether to perform inference. Defaults to True. + fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True. + + Returns: + tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score. + """ + + if fmaps is None: + batch_num, frame_num, image_dim, height, width = images.shape + reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width) + fmaps = self.process_images_to_fmaps(reshaped_image) + fmaps = fmaps.reshape(batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]) + + if inference: + torch.cuda.empty_cache() + + # Coarse prediction + coarse_pred_track_lists, pred_vis = self.coarse_predictor( + query_points=query_points, fmaps=fmaps, iters=coarse_iters, down_ratio=self.coarse_down_ratio + ) + coarse_pred_track = coarse_pred_track_lists[-1] + + if inference: + torch.cuda.empty_cache() + + if fine_tracking: + # Refine the coarse prediction + fine_pred_track, pred_score = refine_track( + images, self.fine_fnet, self.fine_predictor, coarse_pred_track, compute_score=False, chunk=fine_chunk + ) + + if inference: + torch.cuda.empty_cache() + else: + fine_pred_track = coarse_pred_track + pred_score = torch.ones_like(pred_vis) + + return fine_pred_track, coarse_pred_track, pred_vis, pred_score + + def process_images_to_fmaps(self, images): + """ + This function processes images for inference. + + Args: + images (torch.Tensor): The images to be processed with shape S x 3 x H x W. + + Returns: + torch.Tensor: The processed feature maps. + """ + if self.coarse_down_ratio > 1: + # whether or not scale down the input images to save memory + fmaps = self.coarse_fnet( + F.interpolate(images, scale_factor=1 / self.coarse_down_ratio, mode="bilinear", align_corners=True) + ) + else: + fmaps = self.coarse_fnet(images) + + return fmaps diff --git a/vggt/vggt/dependency/vggsfm_utils.py b/vggt/vggt/dependency/vggsfm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f7d9ba6a28da07b7f030a17730f4826feaa828e --- /dev/null +++ b/vggt/vggt/dependency/vggsfm_utils.py @@ -0,0 +1,305 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import warnings +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pycolmap +import torch +import torch.nn.functional as F +from lightglue import ALIKED, SIFT, SuperPoint + +from .vggsfm_tracker import TrackerPredictor + +# Suppress verbose logging from dependencies +logging.getLogger("dinov2").setLevel(logging.WARNING) +warnings.filterwarnings("ignore", message="xFormers is available") +warnings.filterwarnings("ignore", message="dinov2") + +# Constants +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +def build_vggsfm_tracker(model_path=None): + """ + Build and initialize the VGGSfM tracker. + + Args: + model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace. + + Returns: + Initialized tracker model in eval mode. + """ + tracker = TrackerPredictor() + + if model_path is None: + default_url = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt" + tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url)) + else: + tracker.load_state_dict(torch.load(model_path)) + + tracker.eval() + return tracker + + +def generate_rank_by_dino( + images, query_frame_num, image_size=336, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=False +): + """ + Generate a ranking of frames using DINO ViT features. + + Args: + images: Tensor of shape (S, 3, H, W) with values in range [0, 1] + query_frame_num: Number of frames to select + image_size: Size to resize images to before processing + model_name: Name of the DINO model to use + device: Device to run the model on + spatial_similarity: Whether to use spatial token similarity or CLS token similarity + + Returns: + List of frame indices ranked by their representativeness + """ + # Resize images to the target size + images = F.interpolate(images, (image_size, image_size), mode="bilinear", align_corners=False) + + # Load DINO model + dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name) + dino_v2_model.eval() + dino_v2_model = dino_v2_model.to(device) + + # Normalize images using ResNet normalization + resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) + resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) + images_resnet_norm = (images - resnet_mean) / resnet_std + + with torch.no_grad(): + frame_feat = dino_v2_model(images_resnet_norm, is_training=True) + + # Process features based on similarity type + if spatial_similarity: + frame_feat = frame_feat["x_norm_patchtokens"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + + # Compute the similarity matrix + frame_feat_norm = frame_feat_norm.permute(1, 0, 2) + similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + similarity_matrix = similarity_matrix.mean(dim=0) + else: + frame_feat = frame_feat["x_norm_clstoken"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + + distance_matrix = 100 - similarity_matrix.clone() + + # Ignore self-pairing + similarity_matrix.fill_diagonal_(-100) + similarity_sum = similarity_matrix.sum(dim=1) + + # Find the most common frame + most_common_frame_index = torch.argmax(similarity_sum).item() + + # Conduct FPS sampling starting from the most common frame + fps_idx = farthest_point_sampling(distance_matrix, 
query_frame_num, most_common_frame_index) + + # Clean up all tensors and models to free memory + del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix + del dino_v2_model + torch.cuda.empty_cache() + + return fps_idx + + +def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0): + """ + Farthest point sampling algorithm to select diverse frames. + + Args: + distance_matrix: Matrix of distances between frames + num_samples: Number of frames to select + most_common_frame_index: Index of the first frame to select + + Returns: + List of selected frame indices + """ + distance_matrix = distance_matrix.clamp(min=0) + N = distance_matrix.size(0) + + # Initialize with the most common frame + selected_indices = [most_common_frame_index] + check_distances = distance_matrix[selected_indices] + + while len(selected_indices) < num_samples: + # Find the farthest point from the current set of selected points + farthest_point = torch.argmax(check_distances) + selected_indices.append(farthest_point.item()) + + check_distances = distance_matrix[farthest_point] + # Mark already selected points to avoid selecting them again + check_distances[selected_indices] = 0 + + # Break if all points have been selected + if len(selected_indices) == N: + break + + return selected_indices + + +def calculate_index_mappings(query_index, S, device=None): + """ + Construct an order that switches [query_index] and [0] + so that the content of query_index would be placed at [0]. + + Args: + query_index: Index to swap with 0 + S: Total number of elements + device: Device to place the tensor on + + Returns: + Tensor of indices with the swapped order + """ + new_order = torch.arange(S) + new_order[0] = query_index + new_order[query_index] = 0 + if device is not None: + new_order = new_order.to(device) + return new_order + + +def switch_tensor_order(tensors, order, dim=1): + """ + Reorder tensors along a specific dimension according to the given order. + + Args: + tensors: List of tensors to reorder + order: Tensor of indices specifying the new order + dim: Dimension along which to reorder + + Returns: + List of reordered tensors + """ + return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors] + + +def initialize_feature_extractors(max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"): + """ + Initialize feature extractors that can be reused based on a method string. 
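+ Multiple extractor names can be combined with "+" in extractor_method; keypoints from each extractor are concatenated later by extract_keypoints.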
+ + Args: + max_query_num: Maximum number of keypoints to extract + det_thres: Detection threshold for keypoint extraction + extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift") + device: Device to run extraction on + + Returns: + Dictionary of initialized extractors + """ + extractors = {} + methods = extractor_method.lower().split("+") + + for method in methods: + method = method.strip() + if method == "aliked": + aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) + extractors["aliked"] = aliked_extractor.to(device).eval() + elif method == "sp": + sp_extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres) + extractors["sp"] = sp_extractor.to(device).eval() + elif method == "sift": + sift_extractor = SIFT(max_num_keypoints=max_query_num) + extractors["sift"] = sift_extractor.to(device).eval() + else: + print(f"Warning: Unknown feature extractor '{method}', ignoring.") + + if not extractors: + print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.") + aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) + extractors["aliked"] = aliked_extractor.to(device).eval() + + return extractors + + +def extract_keypoints(query_image, extractors, round_keypoints=True): + """ + Extract keypoints using pre-initialized feature extractors. + + Args: + query_image: Input image tensor (3xHxW, range [0, 1]) + extractors: Dictionary of initialized extractors + + Returns: + Tensor of keypoint coordinates (1xNx2) + """ + query_points = None + + with torch.no_grad(): + for extractor_name, extractor in extractors.items(): + query_points_data = extractor.extract(query_image, invalid_mask=None) + extractor_points = query_points_data["keypoints"] + if round_keypoints: + extractor_points = extractor_points.round() + + if query_points is not None: + query_points = torch.cat([query_points, extractor_points], dim=1) + else: + query_points = extractor_points + + return query_points + + +def predict_tracks_in_chunks( + track_predictor, images_feed, query_points_list, fmaps_feed, fine_tracking, num_splits=None, fine_chunk=40960 +): + """ + Process a list of query points to avoid memory issues. + + Args: + track_predictor (object): The track predictor object used for predicting tracks. + images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images. + query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points. + fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker. + fine_tracking (bool): Whether to perform fine tracking. + num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility. + + Returns: + tuple: A tuple containing the concatenated predicted tracks, visibility, and scores. 
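+ Chunk results are concatenated along the track dimension (dim=2 of the B x S x N x ... outputs).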
+ """ + # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility + if not isinstance(query_points_list, (list, tuple)): + query_points = query_points_list + if num_splits is None: + num_splits = 1 + query_points_list = torch.chunk(query_points, num_splits, dim=1) + + # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple) + if isinstance(query_points_list, tuple): + query_points_list = list(query_points_list) + + fine_pred_track_list = [] + pred_vis_list = [] + pred_score_list = [] + + for split_points in query_points_list: + # Feed into track predictor for each split + fine_pred_track, _, pred_vis, pred_score = track_predictor( + images_feed, split_points, fmaps=fmaps_feed, fine_tracking=fine_tracking, fine_chunk=fine_chunk + ) + fine_pred_track_list.append(fine_pred_track) + pred_vis_list.append(pred_vis) + pred_score_list.append(pred_score) + + # Concatenate the results from all splits + fine_pred_track = torch.cat(fine_pred_track_list, dim=2) + pred_vis = torch.cat(pred_vis_list, dim=2) + + if pred_score is not None: + pred_score = torch.cat(pred_score_list, dim=2) + else: + pred_score = None + + return fine_pred_track, pred_vis, pred_score diff --git a/vggt/vggt/heads/__pycache__/camera_head.cpython-310.pyc b/vggt/vggt/heads/__pycache__/camera_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31eb8e8ee390a0985f535cdacd4042c445e09306 Binary files /dev/null and b/vggt/vggt/heads/__pycache__/camera_head.cpython-310.pyc differ diff --git a/vggt/vggt/heads/__pycache__/camera_head.cpython-39.pyc b/vggt/vggt/heads/__pycache__/camera_head.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4fb8ba191baaf4cbb61cd3f4801ed8539fefd7f Binary files /dev/null and b/vggt/vggt/heads/__pycache__/camera_head.cpython-39.pyc differ diff --git a/vggt/vggt/heads/__pycache__/dpt_head.cpython-310.pyc b/vggt/vggt/heads/__pycache__/dpt_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad33b69d2deb640a9591c9619e4d1189bd3f2a93 Binary files /dev/null and b/vggt/vggt/heads/__pycache__/dpt_head.cpython-310.pyc differ diff --git a/vggt/vggt/heads/__pycache__/dpt_head.cpython-39.pyc b/vggt/vggt/heads/__pycache__/dpt_head.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0817c6adb33307f4d6b074709945ca4d80fd5e63 Binary files /dev/null and b/vggt/vggt/heads/__pycache__/dpt_head.cpython-39.pyc differ diff --git a/vggt/vggt/heads/__pycache__/head_act.cpython-310.pyc b/vggt/vggt/heads/__pycache__/head_act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aec8178095bb2f5688e50ccb6419a4f8adb5014a Binary files /dev/null and b/vggt/vggt/heads/__pycache__/head_act.cpython-310.pyc differ diff --git a/vggt/vggt/heads/__pycache__/head_act.cpython-39.pyc b/vggt/vggt/heads/__pycache__/head_act.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64bea500b1846fd5374a9e2dfa90514bf4029735 Binary files /dev/null and b/vggt/vggt/heads/__pycache__/head_act.cpython-39.pyc differ diff --git a/vggt/vggt/heads/__pycache__/track_head.cpython-310.pyc b/vggt/vggt/heads/__pycache__/track_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cece504f5292edd5a26c0b3d8cfbe53c42fea1c7 Binary files /dev/null and b/vggt/vggt/heads/__pycache__/track_head.cpython-310.pyc differ diff --git 
a/vggt/vggt/heads/__pycache__/track_head.cpython-39.pyc b/vggt/vggt/heads/__pycache__/track_head.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ae245e53e88387e8ef6a7e1b3d0f3d2556053d3 Binary files /dev/null and b/vggt/vggt/heads/__pycache__/track_head.cpython-39.pyc differ diff --git a/vggt/vggt/heads/__pycache__/utils.cpython-310.pyc b/vggt/vggt/heads/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2ec3c5e59a9b5fdbe5a9cec499edead369aaa30 Binary files /dev/null and b/vggt/vggt/heads/__pycache__/utils.cpython-310.pyc differ diff --git a/vggt/vggt/heads/__pycache__/utils.cpython-39.pyc b/vggt/vggt/heads/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..190df340f565953f4392660734ef795f4dd9efba Binary files /dev/null and b/vggt/vggt/heads/__pycache__/utils.cpython-39.pyc differ diff --git a/vggt/vggt/heads/camera_head.py b/vggt/vggt/heads/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1ffb57d6c675dd7ef6166deaf4e9a3354b68dd --- /dev/null +++ b/vggt/vggt/heads/camera_head.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vggt.layers import Mlp +from vggt.layers.block import Block +from vggt.heads.head_act import activate_pose + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + + # Build the trunk using a sequence of transformer blocks. + self.trunk = nn.Sequential( + *[ + Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. + self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) + + # Adaptive layer normalization without affine parameters. 
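+ # The shift, scale, and gate tensors produced by poseLN_modulation are applied to this normalization in trunk_fn via modulate().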
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0) + + def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape + pred_pose_enc = None + pred_pose_enc_list = [] + + for _ in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + pose_tokens_modulated = self.trunk(pose_tokens_modulated) + # Compute the delta update for the pose encoding. + pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. + activated_pose = activate_pose( + pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act + ) + pred_pose_enc_list.append(activated_pose) + + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. + """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift diff --git a/vggt/vggt/heads/dpt_head.py b/vggt/vggt/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..73978a87bf3ff134e53076ad20135bfac3045341 --- /dev/null +++ b/vggt/vggt/heads/dpt_head.py @@ -0,0 +1,484 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ + +# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 + + +import os +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + """ + DPT Head for dense prediction tasks. + + This implementation follows the architecture described in "Vision Transformers for Dense Prediction" + (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer + backbone and produces dense predictions by fusing multi-scale features. + + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. + """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] + ) + + # Resize layers for upsampling feature maps. + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + self.scratch = _make_scratch(out_channels, features, expand=False) + + # Attach additional modules to scratch. 
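+ # Four fusion blocks merge features from deep to shallow; refinenet4 handles the deepest map and is built with has_residual=False because it receives no higher-level input.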
+ self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. + + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = images.shape + + # If frames_chunk_size is not specified or greater than S, process all frames at once + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + # Otherwise, process frames in chunks to manage memory usage + assert frames_chunk_size > 0 + + # Process frames in batches + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + # Process batch of frames + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + # Concatenate results along the sequence dimension + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Implementation of the forward pass through the DPT head. + + This method processes a specific chunk of frames from the sequence. + + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. 
+ images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + if frames_start_idx is not None and frames_end_idx is not None: + images = images[:, frames_start_idx:frames_end_idx].contiguous() + + B, S, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + x = x.reshape(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. + out = custom_interpolate( + out, + (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.view(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) + + preds = preds.view(B, S, *preds.shape[1:]) + conf = conf.view(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. 
+ """ + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +################################################################################ +# Modules +################################################################################ + + +def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. + """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/vggt/vggt/heads/head_act.py b/vggt/vggt/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489 --- /dev/null +++ b/vggt/vggt/heads/head_act.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): + """ + Activate pose parameters with specified activation functions. + + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. 
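+ "inv_log" applies inverse_log_transform, i.e. sign(y) * (exp(|y|) - 1), the inverse of a sign-preserving log(1 + |y|) compression.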
+ + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. + + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/vggt/vggt/heads/track_head.py b/vggt/vggt/heads/track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f1d9bd83cca1f74f97a644a02b984904f84706 --- /dev/null +++ b/vggt/vggt/heads/track_head.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +from .dpt_head import DPTHead +from .track_modules.base_track_predictor import BaseTrackerPredictor + + +class TrackHead(nn.Module): + """ + Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. + The tracking is performed iteratively, refining predictions over multiple iterations. + """ + + def __init__( + self, + dim_in, + patch_size=14, + features=128, + iters=4, + predict_conf=True, + stride=2, + corr_levels=7, + corr_radius=4, + hidden_size=384, + ): + """ + Initialize the TrackHead module. + + Args: + dim_in (int): Input dimension of tokens from the backbone. + patch_size (int): Size of image patches used in the vision transformer. + features (int): Number of feature channels in the feature extractor output. 
+ iters (int): Number of refinement iterations for tracking predictions. + predict_conf (bool): Whether to predict confidence scores for tracked points. + stride (int): Stride value for the tracker predictor. + corr_levels (int): Number of correlation pyramid levels + corr_radius (int): Radius for correlation computation, controlling the search area. + hidden_size (int): Size of hidden layers in the tracker network. + """ + super().__init__() + + self.patch_size = patch_size + + # Feature extractor based on DPT architecture + # Processes tokens into feature maps for tracking + self.feature_extractor = DPTHead( + dim_in=dim_in, + patch_size=patch_size, + features=features, + feature_only=True, # Only output features, no activation + down_ratio=2, # Reduces spatial dimensions by factor of 2 + pos_embed=False, + ) + + # Tracker module that predicts point trajectories + # Takes feature maps and predicts coordinates and visibility + self.tracker = BaseTrackerPredictor( + latent_dim=features, # Match the output_dim of feature extractor + predict_conf=predict_conf, + stride=stride, + corr_levels=corr_levels, + corr_radius=corr_radius, + hidden_size=hidden_size, + ) + + self.iters = iters + + def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): + """ + Forward pass of the TrackHead. + + Args: + aggregated_tokens_list (list): List of aggregated tokens from the backbone. + images (torch.Tensor): Input images of shape (B, S, C, H, W) where: + B = batch size, S = sequence length. + patch_start_idx (int): Starting index for patch tokens. + query_points (torch.Tensor, optional): Initial query points to track. + If None, points are initialized by the tracker. + iters (int, optional): Number of refinement iterations. If None, uses self.iters. + + Returns: + tuple: + - coord_preds (torch.Tensor): Predicted coordinates for tracked points. + - vis_scores (torch.Tensor): Visibility scores for tracked points. + - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). + """ + B, S, _, H, W = images.shape + + # Extract features from tokens + # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 + feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) + + # Use default iterations if not specified + if iters is None: + iters = self.iters + + # Perform tracking using the extracted features + coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters) + + return coord_preds, vis_scores, conf_scores diff --git a/vggt/vggt/heads/track_modules/__init__.py b/vggt/vggt/heads/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/vggt/vggt/heads/track_modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
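As a quick reference for the activation conventions defined in heads/head_act.py above, here is a minimal usage sketch. It assumes the inner vggt directory is importable as the vggt package; the input shapes and random values are placeholders, while the activation names and the 9-dimensional absT_quaR_FoV split (3 translation + 4 quaternion + 2 field-of-view values) are taken from the code above.

import torch

from vggt.heads.head_act import activate_head, activate_pose

# Dense-head output: last channel is confidence, the remaining channels are (x, y, z); shape (B, C, H, W).
out = torch.randn(2, 4, 64, 64)
pts3d, conf = activate_head(out, activation="inv_log", conf_activation="expp1")
print(pts3d.shape, conf.shape)  # torch.Size([2, 64, 64, 3]) torch.Size([2, 64, 64]); conf > 1 under "expp1"

# Camera encoding (absT_quaR_FoV): 3 translation + 4 quaternion + 2 field-of-view values per frame.
pose_enc = torch.randn(2, 8, 9)  # (B, S, 9)
pose = activate_pose(pose_enc, trans_act="linear", quat_act="linear", fl_act="relu")
print(pose.shape)  # torch.Size([2, 8, 9]); only the field-of-view slice is passed through ReLU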
diff --git a/vggt/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc b/vggt/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9873ea6e34f239589e43f4f4a0c639be11de4f3d Binary files /dev/null and b/vggt/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/vggt/vggt/heads/track_modules/__pycache__/__init__.cpython-39.pyc b/vggt/vggt/heads/track_modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88988d4f2335d6d5425977304bf9725162551ab8 Binary files /dev/null and b/vggt/vggt/heads/track_modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/vggt/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc b/vggt/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c2d60885293341fc385d15365f2d471cccffdf3 Binary files /dev/null and b/vggt/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc differ diff --git a/vggt/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-39.pyc b/vggt/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc7365a9ece4a8dbe1ba0b84b3ea447917edf1da Binary files /dev/null and b/vggt/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-39.pyc differ diff --git a/vggt/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc b/vggt/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1325a6c8d5d0ef1b9918be4b2528a9f712a97b77 Binary files /dev/null and b/vggt/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc differ diff --git a/vggt/vggt/heads/track_modules/__pycache__/blocks.cpython-39.pyc b/vggt/vggt/heads/track_modules/__pycache__/blocks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58cc6174e84609804d729c799080824daf259061 Binary files /dev/null and b/vggt/vggt/heads/track_modules/__pycache__/blocks.cpython-39.pyc differ diff --git a/vggt/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc b/vggt/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5ca65bc39dab6c583ff9c226d8ecbb7f696d784 Binary files /dev/null and b/vggt/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc differ diff --git a/vggt/vggt/heads/track_modules/__pycache__/modules.cpython-39.pyc b/vggt/vggt/heads/track_modules/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d0e7a502b8d90a34f68773b6987e746a619a11a Binary files /dev/null and b/vggt/vggt/heads/track_modules/__pycache__/modules.cpython-39.pyc differ diff --git a/vggt/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc b/vggt/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a5240ab56937b212a163453df4d5905cbc8e702 Binary files /dev/null and b/vggt/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc differ diff --git a/vggt/vggt/heads/track_modules/__pycache__/utils.cpython-39.pyc b/vggt/vggt/heads/track_modules/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0aad508a20ccf594c72c5a9ac77942798a4aee0 Binary files 
/dev/null and b/vggt/vggt/heads/track_modules/__pycache__/utils.cpython-39.pyc differ diff --git a/vggt/vggt/heads/track_modules/base_track_predictor.py b/vggt/vggt/heads/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce8ec4b66fff236e015d1bcaf85c8237a52be7a --- /dev/null +++ b/vggt/vggt/heads/track_modules/base_track_predictor.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +from .blocks import EfficientUpdateFormer, CorrBlock +from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed +from .modules import Mlp + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=1, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + max_scale=518, + predict_conf=True, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + and https://github.com/facebookresearch/vggsfm + """ + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.max_scale = max_scale + self.predict_conf = predict_conf + + self.flows_emb_dim = latent_dim // 2 + + self.corr_mlp = Mlp( + in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, + hidden_features=self.hidden_size, + out_features=self.latent_dim, + ) + + self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 + + self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.fmap_norm = nn.LayerNorm(self.latent_dim) + self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) + + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + if predict_conf: + self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
+ note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2, "Input points must be 2D coordinates" + + # apply a layernorm to fmaps here + fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) + fmaps = fmaps.permute(0, 1, 4, 2, 3) + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + + coord_preds = [] + + # Iterative Refinement + for _ in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + fcorrs = fcorr_fn.corr_sample(track_feats, coords) + + corr_dim = fcorrs.shape[3] + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) + fcorrs_ = self.corr_mlp(fcorrs_) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) + sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) + + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) + + x = transformer_input + sampled_pos_emb + + # Add the query ref token to the track feats + query_ref_token = torch.cat( + [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 + ) + x = x + query_ref_token.to(x.device).to(x.dtype) + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta, _ = self.updateformer(x) + + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ + + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC + + # B x S x N 
x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + vis_e = torch.sigmoid(vis_e) + + if self.predict_conf: + conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + if apply_sigmoid: + conf_e = torch.sigmoid(conf_e) + else: + conf_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat, conf_e + else: + return coord_preds, vis_e, conf_e diff --git a/vggt/vggt/heads/track_modules/blocks.py b/vggt/vggt/heads/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..15c161c89ef99742b0f2c6f397c9121fe9301e08 --- /dev/null +++ b/vggt/vggt/heads/track_modules/blocks.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import bilinear_sampler +from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. + """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + + # Add input LayerNorm before linear projection + self.input_norm = nn.LayerNorm(input_dim) + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + + # Add output LayerNorm before final projection + self.output_norm = nn.LayerNorm(hidden_size) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) 
+ if module.bias is not None: + nn.init.constant_(module.bias, 0) + torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) + + self.apply(_basic_init) + + def forward(self, input_tensor, mask=None): + # Apply input LayerNorm + input_tensor = self.input_norm(input_tensor) + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + # Apply output LayerNorm before final projection + tokens = self.output_norm(tokens) + flow = self.flow_head(tokens) + + return flow, None + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): + """ + Build a pyramid of feature maps from the input. + + fmaps: Tensor (B, S, C, H, W) + num_levels: number of pyramid levels (each downsampled by factor 2) + radius: search radius for sampling correlation + multiple_track_feats: if True, split the target features per pyramid level + padding_mode: passed to grid_sample / bilinear_sampler + """ + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.num_levels = num_levels + self.radius = radius + self.padding_mode = padding_mode + self.multiple_track_feats = multiple_track_feats + + # Build pyramid: each level is half the spatial resolution of the previous + self.fmaps_pyramid = [fmaps] # level 0 is full resolution + current_fmaps = fmaps + for i in range(num_levels - 1): + B, S, C, H, W = current_fmaps.shape + # Merge batch & sequence dimensions + current_fmaps = current_fmaps.reshape(B * S, C, H, W) + # Avg pool down by factor 2 + current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) + _, _, H_new, W_new = current_fmaps.shape + current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) + self.fmaps_pyramid.append(current_fmaps) + + # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. + # This grid is added to the (scaled) coordinate centroids. + r = self.radius + dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) + # delta: for every (dy,dx) displacement (i.e. 
Δx, Δy) + self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) + + def corr_sample(self, targets, coords): + """ + Instead of storing the entire correlation pyramid, we compute each level's correlation + volume, sample it immediately, then discard it. This saves GPU memory. + + Args: + targets: Tensor (B, S, N, C) — features for the current targets. + coords: Tensor (B, S, N, 2) — coordinates at full resolution. + + Returns: + Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) + """ + B, S, N, C = targets.shape + + # If you have multiple track features, split them per level. + if self.multiple_track_feats: + targets_split = torch.split(targets, C // self.num_levels, dim=-1) + + out_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + # Get current spatial resolution H, W for this pyramid level. + B, S, C, H, W = fmaps.shape + # Reshape feature maps for correlation computation: + # fmap2s: (B, S, C, H*W) + fmap2s = fmaps.view(B, S, C, H * W) + # Choose appropriate target features. + fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) + + # Compute correlation directly + corrs = compute_corr_level(fmap1, fmap2s, C) + corrs = corrs.view(B, S, N, H, W) + + # Prepare sampling grid: + # Scale down the coordinates for the current level. + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) + # Make sure our precomputed delta grid is on the same device/dtype. + delta_lvl = self.delta.to(coords.device).to(coords.dtype) + # Now the grid for grid_sample is: + # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) + coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) + + # Sample from the correlation volume using bilinear interpolation. + # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. + corrs_sampled = bilinear_sampler( + corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode + ) + # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. + corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) + out_pyramid.append(corrs_sampled) + + # Concatenate all levels along the last dimension. + out = torch.cat(out_pyramid, dim=-1).contiguous() + return out + + +def compute_corr_level(fmap1, fmap2s, C): + # fmap1: (B, S, N, C) + # fmap2s: (B, S, C, H*W) + corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) + corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) + return corrs / math.sqrt(C) diff --git a/vggt/vggt/heads/track_modules/modules.py b/vggt/vggt/heads/track_modules/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..12de4f1ad76364d4665e53ac80e1037fadf98d08 --- /dev/null +++ b/vggt/vggt/heads/track_modules/modules.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
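A minimal, self-contained sketch of the correlation step implemented by compute_corr_level above, using assumed toy shapes (these sizes are illustrative, not taken from the model): CorrBlock correlates the per-track target features against each flattened pyramid level with a dot product scaled by 1/sqrt(C), reshapes the result to a spatial map, and only then samples a local window around each coordinate.

import math
import torch

B, S, N, C, H, W = 1, 2, 3, 8, 4, 4                  # toy sizes, for illustration only
fmap1 = torch.randn(B, S, N, C)                       # per-track target features
fmap2s = torch.randn(B, S, C, H * W)                  # one pyramid level, flattened over H*W

# Same computation as compute_corr_level: dot-product similarity scaled by 1/sqrt(C)
corrs = torch.matmul(fmap1, fmap2s) / math.sqrt(C)    # (B, S, N, H*W)
corrs = corrs.view(B, S, N, H, W)                     # spatial map sampled in corr_sample

In corr_sample, each level's corrs volume is then sampled with bilinear_sampler at centroid_lvl + delta_lvl, and the per-level (2r+1)^2 windows are concatenated into the feature that corr_mlp consumes.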
+ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs, + ): + """ + Self attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, 
batch_first=True, **block_kwargs) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): + """ + Cross attention block + """ + super().__init__() + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/vggt/vggt/heads/track_modules/utils.py b/vggt/vggt/heads/track_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f1fffeaedd33c7f1c2ef54220e24a2a0e5a57b2 --- /dev/null +++ b/vggt/vggt/heads/track_modules/utils.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from https://github.com/facebookresearch/vggsfm +# and https://github.com/facebookresearch/co-tracker/tree/main + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Union + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. 
+ """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. 
It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. + """ + coords = coords.detach().clone() + ############################################################ + # IMPORTANT: + coords = coords.to(input.device).to(input.dtype) + ############################################################ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + scale = torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype + ) + else: + scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype) + + coords.mul_(scale) # coords = coords * scale + coords.sub_(1) # coords = coords - 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. 
+ """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/vggt/vggt/heads/utils.py b/vggt/vggt/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..533fc8ae67a75cd0a94d5ca96dc5a0513446c64f --- /dev/null +++ b/vggt/vggt/heads/utils.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + device = pos.device + omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. 
+ """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/vggt/vggt/layers/__init__.py b/vggt/vggt/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/vggt/vggt/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/vggt/vggt/layers/__pycache__/__init__.cpython-310.pyc b/vggt/vggt/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92ecdcc837d9d7e3f0ced7a964420d850336c34e Binary files /dev/null and b/vggt/vggt/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/__init__.cpython-39.pyc b/vggt/vggt/layers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4140820252244685a2a3422f26e604d7f9c13f6 Binary files /dev/null and b/vggt/vggt/layers/__pycache__/__init__.cpython-39.pyc differ diff --git a/vggt/vggt/layers/__pycache__/attention.cpython-310.pyc b/vggt/vggt/layers/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83405a0f9ff48c9e194dfdefb9b50bf62e94723b Binary files /dev/null and b/vggt/vggt/layers/__pycache__/attention.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/attention.cpython-39.pyc b/vggt/vggt/layers/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56f8e7ef17204f4c3e455c302a453331d73e9b86 Binary files /dev/null and b/vggt/vggt/layers/__pycache__/attention.cpython-39.pyc differ diff --git a/vggt/vggt/layers/__pycache__/block.cpython-310.pyc b/vggt/vggt/layers/__pycache__/block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97e1f778f29a0269cfffc55207f17740c08ddb9a Binary files /dev/null and b/vggt/vggt/layers/__pycache__/block.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/block.cpython-39.pyc b/vggt/vggt/layers/__pycache__/block.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c010bddc5a0b5cc0159ea6b642898cfe1ffb5d7 Binary files /dev/null and b/vggt/vggt/layers/__pycache__/block.cpython-39.pyc differ diff --git a/vggt/vggt/layers/__pycache__/drop_path.cpython-310.pyc 
b/vggt/vggt/layers/__pycache__/drop_path.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf1d437aaf169e4d92e8187bce6eca663eddb33a Binary files /dev/null and b/vggt/vggt/layers/__pycache__/drop_path.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/drop_path.cpython-39.pyc b/vggt/vggt/layers/__pycache__/drop_path.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41364d5c55961b983c4298cb0921dabe6b97eaee Binary files /dev/null and b/vggt/vggt/layers/__pycache__/drop_path.cpython-39.pyc differ diff --git a/vggt/vggt/layers/__pycache__/layer_scale.cpython-310.pyc b/vggt/vggt/layers/__pycache__/layer_scale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7e763db9f928c206b06745c5d4407df91111ee3 Binary files /dev/null and b/vggt/vggt/layers/__pycache__/layer_scale.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/layer_scale.cpython-39.pyc b/vggt/vggt/layers/__pycache__/layer_scale.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6e8214207e05c6c935231762ebd94f23ef1eb33 Binary files /dev/null and b/vggt/vggt/layers/__pycache__/layer_scale.cpython-39.pyc differ diff --git a/vggt/vggt/layers/__pycache__/mlp.cpython-310.pyc b/vggt/vggt/layers/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a15d82768acf3f70e227312af1469b32d59f9e2a Binary files /dev/null and b/vggt/vggt/layers/__pycache__/mlp.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/mlp.cpython-39.pyc b/vggt/vggt/layers/__pycache__/mlp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..174523e274688dd59b03dbc1a49d831603c72ac1 Binary files /dev/null and b/vggt/vggt/layers/__pycache__/mlp.cpython-39.pyc differ diff --git a/vggt/vggt/layers/__pycache__/patch_embed.cpython-310.pyc b/vggt/vggt/layers/__pycache__/patch_embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7822266262b17fa4eb4ae8de7efa60edff101a13 Binary files /dev/null and b/vggt/vggt/layers/__pycache__/patch_embed.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/patch_embed.cpython-39.pyc b/vggt/vggt/layers/__pycache__/patch_embed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb56a567edc493bd2d178c3512d007412ef0140a Binary files /dev/null and b/vggt/vggt/layers/__pycache__/patch_embed.cpython-39.pyc differ diff --git a/vggt/vggt/layers/__pycache__/rope.cpython-310.pyc b/vggt/vggt/layers/__pycache__/rope.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b2c8f69444e724e1fad379c21b3c4e07dc3b7fa Binary files /dev/null and b/vggt/vggt/layers/__pycache__/rope.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/rope.cpython-39.pyc b/vggt/vggt/layers/__pycache__/rope.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab7410d4b0443777e653cb55179e3abbf3764942 Binary files /dev/null and b/vggt/vggt/layers/__pycache__/rope.cpython-39.pyc differ diff --git a/vggt/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc b/vggt/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d6889668b6d3701e96d951c9f16706d901d315f Binary files /dev/null and b/vggt/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/swiglu_ffn.cpython-39.pyc 
b/vggt/vggt/layers/__pycache__/swiglu_ffn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfd92a7848300cd0fdfd2b5facd382b76ede3dce Binary files /dev/null and b/vggt/vggt/layers/__pycache__/swiglu_ffn.cpython-39.pyc differ diff --git a/vggt/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc b/vggt/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e9df8edac1d1ccabf910f28dbeb1d8d9aa094cf Binary files /dev/null and b/vggt/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc differ diff --git a/vggt/vggt/layers/__pycache__/vision_transformer.cpython-39.pyc b/vggt/vggt/layers/__pycache__/vision_transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2069e2db8c2c1cd765af9ec751ffd15b4bd056d3 Binary files /dev/null and b/vggt/vggt/layers/__pycache__/vision_transformer.cpython-39.pyc differ diff --git a/vggt/vggt/layers/attention.py b/vggt/vggt/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8e823b4b7a93cca75e4cbab1cdfbbc3121a316fa --- /dev/null +++ b/vggt/vggt/layers/attention.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn +import torch.nn.functional as F + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: + assert pos is None + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise 
AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/vggt/vggt/layers/block.py b/vggt/vggt/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5847352a1f8f5d63da28c99e94270e50ccf3aa --- /dev/null +++ b/vggt/vggt/layers/block.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio + ) + x = drop_add_residual_stochastic_depth( + x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + 
else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, 
residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None), + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None), + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/vggt/vggt/layers/drop_path.py b/vggt/vggt/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/vggt/vggt/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/vggt/vggt/layers/layer_scale.py b/vggt/vggt/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..4ddfc51c3d87370d50175f5b4e649dac1c614ff9 --- /dev/null +++ b/vggt/vggt/layers/layer_scale.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/vggt/vggt/layers/mlp.py b/vggt/vggt/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/vggt/vggt/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/vggt/vggt/layers/patch_embed.py b/vggt/vggt/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..bc19605e4d6e88d06355ae3b1afddc76f595aafe --- /dev/null +++ b/vggt/vggt/layers/patch_embed.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/vggt/vggt/layers/rope.py b/vggt/vggt/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5d33304e55dbd05687bd86752a47a80e5f82df --- /dev/null +++ b/vggt/vggt/layers/rope.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Tuple + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. + + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. 
+ """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. 
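+
+        Example (illustrative only; it exercises the public forward of this module -- at grid
+        position (0, 0) the rotation angle is zero, so tokens pass through unchanged):
+            >>> rope = RotaryPositionEmbedding2D()
+            >>> tokens = torch.randn(1, 2, 6, 8)                    # (B, heads, N, dim)
+            >>> positions = torch.zeros(1, 6, 2, dtype=torch.long)  # every token at (0, 0)
+            >>> torch.allclose(rope(tokens, positions), tokens)
+            True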
+ """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) + horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/vggt/vggt/layers/swiglu_ffn.py b/vggt/vggt/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd991e1deb87141ccd282098d4b9d38fed6ef25 --- /dev/null +++ b/vggt/vggt/layers/swiglu_ffn.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
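+
+# Added commentary (not in the upstream file): the SwiGLU feed-forward below packs the two
+# input projections into a single linear layer `w12`, so for an input x it computes
+#     x1, x2 = split(w12(x));   out = w3(silu(x1) * x2)
+# SwiGLUFFNFused additionally shrinks the requested hidden width to 2/3 of its value,
+# rounded up to a multiple of 8, before constructing the layers.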
+ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +# try: +# if XFORMERS_ENABLED: +# from xformers.ops import SwiGLU + +# XFORMERS_AVAILABLE = True +# warnings.warn("xFormers is available (SwiGLU)") +# else: +# warnings.warn("xFormers is disabled (SwiGLU)") +# raise ImportError +# except ImportError: +SwiGLU = SwiGLUFFN +XFORMERS_AVAILABLE = False + +# warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias) diff --git a/vggt/vggt/layers/vision_transformer.py b/vggt/vggt/layers/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..deda8fde42b1b5b3340132c9c75338c65c9bea3f --- /dev/null +++ b/vggt/vggt/layers/vision_transformer.py @@ -0,0 +1,397 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ +from . 
import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + qk_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.use_reentrant = False # hardcoded to False + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = 
[x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1) + + 
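+        # Final token layout: [CLS token] + [register tokens] + [patch tokens]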
return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=True, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/vggt/vggt/models/__init__.py b/vggt/vggt/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggt/vggt/models/__pycache__/__init__.cpython-39.pyc b/vggt/vggt/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf1d51c15884a0409e547fb35d0e5baadbad1fdc Binary files /dev/null and 
b/vggt/vggt/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/vggt/vggt/models/__pycache__/aggregator.cpython-310.pyc b/vggt/vggt/models/__pycache__/aggregator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d730ad751740fcd2a7d38f124cc59e68436fbfd Binary files /dev/null and b/vggt/vggt/models/__pycache__/aggregator.cpython-310.pyc differ diff --git a/vggt/vggt/models/__pycache__/aggregator.cpython-39.pyc b/vggt/vggt/models/__pycache__/aggregator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a662ccf524041653e245aa1c6aaf888d0c1d764c Binary files /dev/null and b/vggt/vggt/models/__pycache__/aggregator.cpython-39.pyc differ diff --git a/vggt/vggt/models/__pycache__/vggt.cpython-310.pyc b/vggt/vggt/models/__pycache__/vggt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44add6b701d6fdab62af273d450fc138f5e4d907 Binary files /dev/null and b/vggt/vggt/models/__pycache__/vggt.cpython-310.pyc differ diff --git a/vggt/vggt/models/__pycache__/vggt.cpython-39.pyc b/vggt/vggt/models/__pycache__/vggt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..897dc8f13a97443d4a36daf0eeec9a15ff74f3e1 Binary files /dev/null and b/vggt/vggt/models/__pycache__/vggt.cpython-39.pyc differ diff --git a/vggt/vggt/models/aggregator.py b/vggt/vggt/models/aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6b25d6df44a0dbf71b214f5084b2a21fcd087e --- /dev/null +++ b/vggt/vggt/models/aggregator.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from typing import Optional, Tuple, Union, List, Dict, Any + +from vggt.layers import PatchEmbed +from vggt.layers.block import Block +from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter +from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class Aggregator(nn.Module): + """ + The Aggregator applies alternating-attention over input frames, + as described in VGGT: Visual Geometry Grounded Transformer. + + Remember to set model.train() to enable gradient checkpointing to reduce memory usage. + + Args: + img_size (int): Image size in pixels. + patch_size (int): Size of each patch for PatchEmbed. + embed_dim (int): Dimension of the token embeddings. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. + num_register_tokens (int): Number of register tokens. + block_fn (nn.Module): The block type used for attention (Block by default). + qkv_bias (bool): Whether to include bias in QKV projections. + proj_bias (bool): Whether to include bias in the output projection. + ffn_bias (bool): Whether to include bias in MLP layers. + patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". + aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. + aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. 
+ qk_norm (bool): Whether to apply QK normalization. + rope_freq (int): Base frequency for rotary embedding. -1 to disable. + init_values (float): Init scale for layer scale. + """ + + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + patch_embed="dinov2_vitl14_reg", + aa_order=["frame", "global"], + aa_block_size=1, + qk_norm=True, + rope_freq=100, + init_values=0.01, + ): + super().__init__() + + self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim) + + # Initialize rotary position embedding if frequency > 0 + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + self.position_getter = PositionGetter() if self.rope is not None else None + + self.frame_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.global_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.depth = depth + self.aa_order = aa_order + self.patch_size = patch_size + self.aa_block_size = aa_block_size + + # Validate that depth is divisible by aa_block_size + if self.depth % self.aa_block_size != 0: + raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") + + self.aa_block_num = self.depth // self.aa_block_size + + # Note: We have two camera tokens, one for the first frame and one for the rest + # The same applies for register tokens + self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) + self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim)) + + # The patch tokens start after the camera and register tokens + self.patch_start_idx = 1 + num_register_tokens + + # Initialize parameters with small values + nn.init.normal_(self.camera_token, std=1e-6) + nn.init.normal_(self.register_token, std=1e-6) + + # Register normalization constants as buffers + for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)): + self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False) + + self.use_reentrant = False # hardcoded to False + + def __build_patch_embed__( + self, + patch_embed, + img_size, + patch_size, + num_register_tokens, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + embed_dim=1024, + ): + """ + Build the patch embed layer. If 'conv', we use a + simple PatchEmbed conv layer. Otherwise, we use a vision transformer. 
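+        Supported vision-transformer keys (see the mapping below): "dinov2_vitl14_reg",
+        "dinov2_vitb14_reg", "dinov2_vits14_reg", "dinov2_vitg2_reg".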
+ """ + + if "conv" in patch_embed: + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) + else: + vit_models = { + "dinov2_vitl14_reg": vit_large, + "dinov2_vitb14_reg": vit_base, + "dinov2_vits14_reg": vit_small, + "dinov2_vitg2_reg": vit_giant2, + } + + self.patch_embed = vit_models[patch_embed]( + img_size=img_size, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + block_chunks=block_chunks, + init_values=init_values, + ) + + # Disable gradient updates for mask token + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]: + """ + Args: + images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + + Returns: + (list[torch.Tensor], int): + The list of outputs from the attention blocks, + and the patch_start_idx indicating where patch tokens begin. + """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + # Normalize images and reshape for patch embed + images = (images - self._resnet_mean) / self._resnet_std + + # Reshape to [B*S, C, H, W] for patch embedding + images = images.view(B * S, C_in, H, W) + patch_tokens = self.patch_embed(images) + + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + _, P, C = patch_tokens.shape + + # Expand camera and register tokens to match batch size and sequence length + camera_token = slice_expand_and_flatten(self.camera_token, B, S) + register_token = slice_expand_and_flatten(self.register_token, B, S) + + # Concatenate special tokens with patch tokens + tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) + + pos = None + if self.rope is not None: + pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype) + pos = torch.cat([pos_special, pos], dim=1) + + # update P because we added special tokens + _, P, C = tokens.shape + + frame_idx = 0 + global_idx = 0 + output_list = [] + + for _ in range(self.aa_block_num): + for attn_type in self.aa_order: + if attn_type == "frame": + tokens, frame_idx, frame_intermediates = self._process_frame_attention( + tokens, B, S, P, C, frame_idx, pos=pos + ) + elif attn_type == "global": + tokens, global_idx, global_intermediates = self._process_global_attention( + tokens, B, S, P, C, global_idx, pos=pos + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + for i in range(len(frame_intermediates)): + # concat frame and global intermediates, [B x S x P x 2C] + concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) + output_list.append(concat_inter) + + del concat_inter + del frame_intermediates + del global_intermediates + return output_list, self.patch_start_idx + + def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): + """ + Process frame attention blocks. We keep tokens in shape (B*S, P, C). 
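+        Each of the S frames is treated as an independent sequence of P tokens here, so
+        attention stays within a single frame; the companion _process_global_attention
+        reshapes to (B, S*P, C) so that tokens can attend across all frames.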
+ """ + # If needed, reshape tokens or positions: + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B, S, P, C).view(B * S, P, C) + + if pos is not None and pos.shape != (B * S, P, 2): + pos = pos.view(B, S, P, 2).view(B * S, P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + if self.training: + tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant) + else: + tokens = self.frame_blocks[frame_idx](tokens, pos=pos) + frame_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): + """ + Process global attention blocks. We keep tokens in shape (B, S*P, C). + """ + if tokens.shape != (B, S * P, C): + tokens = tokens.view(B, S, P, C).view(B, S * P, C) + + if pos is not None and pos.shape != (B, S * P, 2): + pos = pos.view(B, S, P, 2).view(B, S * P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + if self.training: + tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant) + else: + tokens = self.global_blocks[global_idx](tokens, pos=pos) + global_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, global_idx, intermediates + + +def slice_expand_and_flatten(token_tensor, B, S): + """ + Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: + 1) Uses the first position (index=0) for the first frame only + 2) Uses the second position (index=1) for all remaining frames (S-1 frames) + 3) Expands both to match batch size B + 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token + followed by (S-1) second-position tokens + 5) Flattens to (B*S, X, C) for processing + + Returns: + torch.Tensor: Processed tokens with shape (B*S, X, C) + """ + + # Slice out the "query" tokens => shape (1, 1, ...) + query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) + # Slice out the "other" tokens => shape (1, S-1, ...) + others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) + # Concatenate => shape (B, S, ...) + combined = torch.cat([query, others], dim=1) + + # Finally flatten => shape (B*S, ...) + combined = combined.view(B * S, *combined.shape[2:]) + return combined diff --git a/vggt/vggt/models/vggt.py b/vggt/vggt/models/vggt.py new file mode 100644 index 0000000000000000000000000000000000000000..686e6f9d3f9e37769c195258f429a66d927375c0 --- /dev/null +++ b/vggt/vggt/models/vggt.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
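+
+# Illustrative usage (a sketch added for orientation, not part of the original file; the
+# checkpoint id below is an assumption -- substitute whichever weights you actually use):
+#
+#   from vggt.utils.load_fn import load_and_preprocess_images
+#   model = VGGT.from_pretrained("facebook/VGGT-1B")    # via PyTorchModelHubMixin
+#   images = load_and_preprocess_images(image_list)     # (S, 3, H, W), values in [0, 1]
+#   with torch.no_grad():
+#       predictions = model(images)                     # dict: pose_enc, depth, world_points, ...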
+ +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin # used for model hub + +from vggt.models.aggregator import Aggregator +from vggt.heads.camera_head import CameraHead +from vggt.heads.dpt_head import DPTHead +from vggt.heads.track_head import TrackHead + + +class VGGT(nn.Module, PyTorchModelHubMixin): + def __init__(self, img_size=518, patch_size=14, embed_dim=1024, + enable_camera=True, enable_point=True, enable_depth=True, enable_track=True): + super().__init__() + + self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + + self.camera_head = CameraHead(dim_in=2 * embed_dim) if enable_camera else None + self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") if enable_point else None + self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") if enable_depth else None + self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) if enable_track else None + + def forward(self, images: torch.Tensor, query_points: torch.Tensor = None): + """ + Forward pass of the VGGT model. + + Args: + images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. + Shape: [N, 2] or [B, N, 2], where N is the number of query points. + Default: None + + Returns: + dict: A dictionary containing the following predictions: + - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) + - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] + - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] + - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] + - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] + - images (torch.Tensor): Original input images, preserved for visualization + + If query_points is provided, also includes: + - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates + - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] + - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] + """ + # If without batch dimension, add it + if len(images.shape) == 4: + images = images.unsqueeze(0) + + if query_points is not None and len(query_points.shape) == 2: + query_points = query_points.unsqueeze(0) + + aggregated_tokens_list, patch_start_idx = self.aggregator(images) + + predictions = {} + + with torch.cuda.amp.autocast(enabled=False): + if self.camera_head is not None: + pose_enc_list = self.camera_head(aggregated_tokens_list) + predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration + predictions["pose_enc_list"] = pose_enc_list + + if self.depth_head is not None: + depth, depth_conf = self.depth_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + predictions["depth"] = depth + predictions["depth_conf"] = depth_conf + + if self.point_head is not None: + pts3d, pts3d_conf = self.point_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + predictions["world_points"] = pts3d + predictions["world_points_conf"] = pts3d_conf + + if self.track_head is not None and 
query_points is not None: + track_list, vis, conf = self.track_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points + ) + predictions["track"] = track_list[-1] # track of the last iteration + predictions["vis"] = vis + predictions["conf"] = conf + + if not self.training: + predictions["images"] = images # store the images for visualization during inference + + return predictions + diff --git a/vggt/vggt/utils/__init__.py b/vggt/vggt/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggt/vggt/utils/__pycache__/__init__.cpython-39.pyc b/vggt/vggt/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e9d54dbb740dbd8f8489eb1daa9bd3d882c6326 Binary files /dev/null and b/vggt/vggt/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/vggt/vggt/utils/__pycache__/geometry.cpython-39.pyc b/vggt/vggt/utils/__pycache__/geometry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6781286c1baec024ad60706d5761489ca952956 Binary files /dev/null and b/vggt/vggt/utils/__pycache__/geometry.cpython-39.pyc differ diff --git a/vggt/vggt/utils/__pycache__/load_fn.cpython-310.pyc b/vggt/vggt/utils/__pycache__/load_fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0c66be0a152540615dcd3de1868c146d167e458 Binary files /dev/null and b/vggt/vggt/utils/__pycache__/load_fn.cpython-310.pyc differ diff --git a/vggt/vggt/utils/__pycache__/load_fn.cpython-39.pyc b/vggt/vggt/utils/__pycache__/load_fn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5b5d6a5594dd779125eed7ec1f5d30f956f3a1c Binary files /dev/null and b/vggt/vggt/utils/__pycache__/load_fn.cpython-39.pyc differ diff --git a/vggt/vggt/utils/__pycache__/pose_enc.cpython-310.pyc b/vggt/vggt/utils/__pycache__/pose_enc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d803d110b6d92b4e35795719b0c922c8609dc2d8 Binary files /dev/null and b/vggt/vggt/utils/__pycache__/pose_enc.cpython-310.pyc differ diff --git a/vggt/vggt/utils/__pycache__/pose_enc.cpython-39.pyc b/vggt/vggt/utils/__pycache__/pose_enc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f60d48dd66d011b1a43f9ac68e1a431832a9c585 Binary files /dev/null and b/vggt/vggt/utils/__pycache__/pose_enc.cpython-39.pyc differ diff --git a/vggt/vggt/utils/__pycache__/rotation.cpython-310.pyc b/vggt/vggt/utils/__pycache__/rotation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09e8fbf16b853fc4fb2aaee277a639fda2cd22eb Binary files /dev/null and b/vggt/vggt/utils/__pycache__/rotation.cpython-310.pyc differ diff --git a/vggt/vggt/utils/__pycache__/rotation.cpython-39.pyc b/vggt/vggt/utils/__pycache__/rotation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a683ccf44f4f4b81c088167f51fd86f75643ae59 Binary files /dev/null and b/vggt/vggt/utils/__pycache__/rotation.cpython-39.pyc differ diff --git a/vggt/vggt/utils/geometry.py b/vggt/vggt/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f555516dbc8a7dd8c7b15e6fbc928a5bfe8f740b --- /dev/null +++ b/vggt/vggt/utils/geometry.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import numpy as np + + +from vggt.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion + + +def unproject_depth_map_to_point_map( + depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray +) -> np.ndarray: + """ + Unproject a batch of depth maps to 3D world coordinates. + + Args: + depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) + extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) + intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) + + Returns: + np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) + """ + if isinstance(depth_map, torch.Tensor): + depth_map = depth_map.cpu().numpy() + if isinstance(extrinsics_cam, torch.Tensor): + extrinsics_cam = extrinsics_cam.cpu().numpy() + if isinstance(intrinsics_cam, torch.Tensor): + intrinsics_cam = intrinsics_cam.cpu().numpy() + + world_points_list = [] + for frame_idx in range(depth_map.shape[0]): + cur_world_points, _, _ = depth_to_world_coords_points( + depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] + ) + world_points_list.append(cur_world_points) + world_points_array = np.stack(world_points_list, axis=0) + + return world_points_array + + +def depth_to_world_coords_points( + depth_map: np.ndarray, + extrinsic: np.ndarray, + intrinsic: np.ndarray, + eps=1e-8, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Convert a depth map to world coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. + + Returns: + tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). + """ + if depth_map is None: + return None, None, None + + # Valid depth mask + point_mask = depth_map > eps + + # Convert depth map to camera coordinates + cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) + + # Multiply with the inverse of extrinsic matrix to transform to world coordinates + # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] + + R_cam_to_world = cam_to_world_extrinsic[:3, :3] + t_cam_to_world = cam_to_world_extrinsic[:3, 3] + + # Apply the rotation and translation to the camera coordinates + world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 + # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world + + return world_coords_points, cam_coords_points, point_mask + + +def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Convert a depth map to camera coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). 
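+
+    The unprojection below follows the pinhole model used in the body:
+        x_cam = (u - cu) * z / fu,   y_cam = (v - cv) * z / fv,   z_cam = z.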
+ + Returns: + tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) + """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" + + # Intrinsic parameters + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + # Generate grid of pixel coordinates + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + # Unproject to camera coordinates + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + # Stack to form camera coordinates + cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + return cam_coords + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. + + Shapes: + se3: (N, 4, 4) + R: (N, 3, 3) + T: (N, 3, 1) + """ + # Check if se3 is a numpy array or a torch tensor + is_numpy = isinstance(se3, np.ndarray) + + # Validate shapes + if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): + raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") + + # Extract R and T if not provided + if R is None: + R = se3[:, :3, :3] # (N,3,3) + if T is None: + T = se3[:, :3, 3:] # (N,3,1) + + # Transpose R + if is_numpy: + # Compute the transpose of the rotation for NumPy + R_transposed = np.transpose(R, (0, 2, 1)) + # -R^T t for NumPy + top_right = -np.matmul(R_transposed, T) + inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) + else: + R_transposed = R.transpose(1, 2) # (N,3,3) + top_right = -torch.bmm(R_transposed, T) # (N,3,1) + inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix + + +# TODO: this code can be further cleaned up + + +def project_world_points_to_camera_points_batch(world_points, cam_extrinsics): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + world_points (torch.Tensor): 3D points of shape BxSxHxWx3. + cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4. 
+ Returns: + """ + # TODO: merge this into project_world_points_to_cam + + # device = world_points.device + # with torch.autocast(device_type=device.type, enabled=False): + ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1) + world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4) + + # extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4) + extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3) + + # world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1) + world_points_h_exp = world_points_h.unsqueeze(-1) + + # Now perform the matrix multiplication + # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1) + camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1) + + return camera_points + + + +def project_world_points_to_cam( + world_points, + cam_extrinsics, + cam_intrinsics=None, + distortion_params=None, + default=0, + only_points_cam=False, +): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + world_points (torch.Tensor): 3D points of shape Px3. + cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. + cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. + distortion_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion. + Returns: + torch.Tensor: Transformed 2D points of shape BxNx2. + """ + device = world_points.device + # with torch.autocast(device_type=device.type, dtype=torch.double): + with torch.autocast(device_type=device.type, enabled=False): + N = world_points.shape[0] # Number of points + B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras + world_points_homogeneous = torch.cat( + [world_points, torch.ones_like(world_points[..., 0:1])], dim=1 + ) # Nx4 + # Reshape for batch processing + world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand( + B, -1, -1 + ) # BxNx4 + + # Step 1: Apply extrinsic parameters + # Transform 3D points to camera coordinate system for all cameras + cam_points = torch.bmm( + cam_extrinsics, world_points_homogeneous.transpose(-1, -2) + ) + + if only_points_cam: + return None, cam_points + + # Step 2: Apply intrinsic parameters and (optional) distortion + image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default) + + return image_points, cam_points + + + +def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0): + """ + Applies intrinsic parameters and optional distortion to the given 3D points. + + Args: + cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. + cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. + distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + default (float, optional): Default value to replace NaNs in the output. + + Returns: + pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. 
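+
+    Example (illustrative only; identity intrinsics, no distortion, a single point at (2, 4, 2)):
+        >>> K = torch.eye(3).unsqueeze(0)                # (1, 3, 3)
+        >>> pts = torch.tensor([[[2.0], [4.0], [2.0]]])  # (1, 3, 1)
+        >>> img_from_cam(K, pts)
+        tensor([[[1., 2.]]])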
+ """ + + # Normalized device coordinates (NDC) + cam_points = cam_points / cam_points[:, 2:3, :] + ndc_xy = cam_points[:, :2, :] + + # Apply distortion if distortion_params are provided + if distortion_params is not None: + x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1]) + distorted_xy = torch.stack([x_distorted, y_distorted], dim=1) + else: + distorted_xy = ndc_xy + + # Prepare cam_points for batch matrix multiplication + cam_coords_homo = torch.cat( + (distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1 + ) # Bx3xN + # Apply intrinsic parameters using batch matrix multiplication + pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN + + # Extract x and y coordinates + pixel_coords = pixel_coords[:, :2, :] # Bx2xN + + # Replace NaNs with default value + pixel_coords = torch.nan_to_num(pixel_coords, nan=default) + + return pixel_coords.transpose(1, 2) # BxNx2 + + + + +def cam_from_img(pred_tracks, intrinsics, extra_params=None): + """ + Normalize predicted tracks based on camera intrinsics. + Args: + intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3]. + pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2]. + extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + Returns: + torch.Tensor: Normalized tracks tensor. + """ + + # We don't want to do intrinsics_inv = torch.inverse(intrinsics) here + # otherwise we can use something like + # tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2)) + + principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2) + focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2) + tracks_normalized = (pred_tracks - principal_point) / focal_length + + if extra_params is not None: + # Apply iterative undistortion + try: + tracks_normalized = iterative_undistortion( + extra_params, tracks_normalized + ) + except: + tracks_normalized = single_undistortion( + extra_params, tracks_normalized + ) + + return tracks_normalized \ No newline at end of file diff --git a/vggt/vggt/utils/helper.py b/vggt/vggt/utils/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..7b019189c85ff86645a4cf3756632aa8d4500649 --- /dev/null +++ b/vggt/vggt/utils/helper.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + + +def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray: + """ + If mask has more than max_trues True values, + randomly keep only max_trues of them and set the rest to False. + """ + # 1D positions of all True entries + true_indices = np.flatnonzero(mask) # shape = (N_true,) + + # if already within budget, return as-is + if true_indices.size <= max_trues: + return mask + + # randomly pick which True positions to keep + sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,) + + # build new flat mask: True only at sampled positions + limited_flat_mask = np.zeros(mask.size, dtype=bool) + limited_flat_mask[sampled_indices] = True + + # restore original shape + return limited_flat_mask.reshape(mask.shape) + + +def create_pixel_coordinate_grid(num_frames, height, width): + """ + Creates a grid of pixel coordinates and frame indices for all frames. 
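+
+    Args:
+        num_frames (int): Number of frames (the leading dimension of the returned grid).
+        height (int): Frame height in pixels.
+        width (int): Frame width in pixels.
+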
+    Returns:
+        points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3) containing
+            the x, y pixel coordinates and the frame index for every pixel.
+    """
+    # Create coordinate grids for a single frame
+    y_grid, x_grid = np.indices((height, width), dtype=np.float32)
+    x_grid = x_grid[np.newaxis, :, :]
+    y_grid = y_grid[np.newaxis, :, :]
+
+    # Broadcast to all frames
+    x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
+    y_coords = np.broadcast_to(y_grid, (num_frames, height, width))
+
+    # Create frame indices and broadcast
+    f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis]
+    f_coords = np.broadcast_to(f_idx, (num_frames, height, width))
+
+    # Stack coordinates and frame indices
+    points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
+
+    return points_xyf
diff --git a/vggt/vggt/utils/load_fn.py b/vggt/vggt/utils/load_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..385d531a9252525a52f08d1ad40463f6878b83bf
--- /dev/null
+++ b/vggt/vggt/utils/load_fn.py
@@ -0,0 +1,234 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from PIL import Image
+from torchvision import transforms as TF
+import numpy as np
+
+
+def load_and_preprocess_images_square(image_path_list, target_size=1024):
+    """
+    Load and preprocess images by center padding to square and resizing to target size.
+    Also returns the position information of original pixels after transformation.
+
+    Args:
+        image_path_list (list): List of paths to image files
+        target_size (int, optional): Target size for both width and height. Defaults to 1024.
+ + Returns: + tuple: ( + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size), + torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image + ) + + Raises: + ValueError: If the input list is empty + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + images = [] + original_coords = [] # Renamed from position_info to be more descriptive + to_tensor = TF.ToTensor() + + for image_path in image_path_list: + # Open image + img = Image.open(image_path) + + # If there's an alpha channel, blend onto white background + if img.mode == "RGBA": + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + img = Image.alpha_composite(background, img) + + # Convert to RGB + img = img.convert("RGB") + + # Get original dimensions + width, height = img.size + + # Make the image square by padding the shorter dimension + max_dim = max(width, height) + + # Calculate padding + left = (max_dim - width) // 2 + top = (max_dim - height) // 2 + + # Calculate scale factor for resizing + scale = target_size / max_dim + + # Calculate final coordinates of original image in target space + x1 = left * scale + y1 = top * scale + x2 = (left + width) * scale + y2 = (top + height) * scale + + # Store original image coordinates and scale + original_coords.append(np.array([x1, y1, x2, y2, width, height])) + + # Create a new black square image and paste original + square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0)) + square_img.paste(img, (left, top)) + + # Resize to target size + square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC) + + # Convert to tensor + img_tensor = to_tensor(square_img) + images.append(img_tensor) + + # Stack all images + images = torch.stack(images) + original_coords = torch.from_numpy(np.array(original_coords)).float() + + # Add additional dimension if single image to ensure correct shape + if len(image_path_list) == 1: + if images.dim() == 3: + images = images.unsqueeze(0) + original_coords = original_coords.unsqueeze(0) + + return images, original_coords + + +def load_and_preprocess_images(image_path_list, mode="crop"): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. + + Args: + image_path_list (list): List of paths to image files + mode (str, optional): Preprocessing mode, either "crop" or "pad". + - "crop" (default): Sets width to 518px and center crops height if needed. + - "pad": Preserves all pixels by making the largest dimension 518px + and padding the smaller dimension to reach a square shape. 
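+
+    Example (an illustrative sketch; in this copy the list entries are PIL.Image objects,
+    which the body converts directly via .convert("RGB"), rather than file paths):
+        images = load_and_preprocess_images([Image.new("RGB", (640, 480))], mode="pad")
+        # -> a tensor of shape (1, 3, 518, 518), per the Returns section below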
+
+    Returns:
+        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
+
+    Raises:
+        ValueError: If the input list is empty or if mode is invalid
+
+    Notes:
+        - Images with different dimensions will be padded with white (value=1.0)
+        - A warning is printed when images have different shapes
+        - When mode="crop": The function ensures width=518px while maintaining aspect ratio
+          and height is center-cropped if larger than 518px
+        - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
+          and the smaller dimension is padded to reach a square shape (518x518)
+        - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
+    """
+    # Check for empty list
+    if len(image_path_list) == 0:
+        raise ValueError("At least 1 image is required")
+
+    # Validate mode
+    if mode not in ["crop", "pad"]:
+        raise ValueError("Mode must be either 'crop' or 'pad'")
+
+    images = []
+    shapes = set()
+    to_tensor = TF.ToTensor()
+    target_size = 518
+
+    # First process all images and collect their shapes
+    for image_path in image_path_list:
+        # NOTE: the original implementation opened each path with Image.open() and
+        # composited RGBA inputs onto a white background before converting to RGB.
+        # This modified version receives already-opened PIL images, so only the
+        # RGB conversion is kept here.
+        img = image_path.convert("RGB")  # Use the passed image directly
+
+        width, height = img.size
+
+        if mode == "pad":
+            # Make the largest dimension 518px while maintaining aspect ratio
+            if width >= height:
+                new_width = target_size
+                new_height = round(height * (new_width / width) / 14) * 14  # Make divisible by 14
+            else:
+                new_height = target_size
+                new_width = round(width * (new_height / height) / 14) * 14  # Make divisible by 14
+        else:  # mode == "crop"
+            # Original behavior: set width to 518px
+            new_width = target_size
+            # Calculate height maintaining aspect ratio, divisible by 14
+            new_height = round(height * (new_width / width) / 14) * 14
+
+        # Resize with new dimensions (width, height)
+        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
+        img = to_tensor(img)  # Convert to tensor (0, 1)
+
+        # Center crop height if it's larger than 518 (only in crop mode)
+        if mode == "crop" and new_height > target_size:
+            start_y = (new_height - target_size) // 2
+            img = img[:, start_y : start_y + target_size, :]
+
+        # For pad mode, pad to make a square of target_size x target_size
+        if mode == "pad":
+            h_padding = target_size - img.shape[1]
+            w_padding = target_size - img.shape[2]
+
+            if h_padding > 0 or w_padding > 0:
+                pad_top = h_padding // 2
+                pad_bottom = h_padding - pad_top
+                pad_left = w_padding // 2
+                pad_right = w_padding - pad_left
+
+                # Pad with white (value=1.0)
+                img = torch.nn.functional.pad(
+                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+                )
+
+        shapes.add((img.shape[1], img.shape[2]))
+        images.append(img)
+
+    # Check if we have different shapes
+    # In theory our model can also work well with different shapes
+    if len(shapes) > 1:
+        print(f"Warning: Found images with different shapes: {shapes}")
+        # Find maximum dimensions
+        max_height = max(shape[0] for shape in shapes)
+        max_width = max(shape[1] for shape in shapes)
+
+        # Pad images if necessary
+        padded_images = []
+        for img in images:
+            h_padding = max_height - img.shape[1]
+            w_padding = max_width - img.shape[2]
+
+            if h_padding > 0 or w_padding > 0:
+                pad_top = h_padding // 2
+                pad_bottom = h_padding - pad_top
+                pad_left = w_padding // 2
+                pad_right = w_padding - pad_left
+
+                img = torch.nn.functional.pad(
+                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+                )
+            padded_images.append(img)
+        images = padded_images
+
+    images = torch.stack(images)  # stack into a single (N, 3, H, W) tensor
+
+    # Ensure correct shape when single image
+    if len(image_path_list) == 1:
+        # Verify shape is (1, C, H, W)
+        if images.dim() == 3:
+            images = images.unsqueeze(0)
+
+    return images
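A minimal usage sketch for the two loaders above (the vggt.utils.load_fn import path is inferred from the file layout and the frame paths are placeholders; note that the modified load_and_preprocess_images expects already-opened PIL images rather than paths):

from PIL import Image
from vggt.utils.load_fn import load_and_preprocess_images, load_and_preprocess_images_square

paths = ["frames/000.png", "frames/001.png"]  # hypothetical inputs

# Square-pad variant: takes file paths and also returns, per image, the
# [x1, y1, x2, y2, width, height] coordinates of the original pixels in target space.
images_sq, original_coords = load_and_preprocess_images_square(paths, target_size=1024)
# images_sq: (2, 3, 1024, 1024), original_coords: (2, 6)

# Crop/pad variant as modified in this diff: takes already-opened PIL images.
pil_images = [Image.open(p) for p in paths]
images = load_and_preprocess_images(pil_images, mode="crop")
# images: (2, 3, H, 518) with H a multiple of 14 and at most 518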
diff --git a/vggt/vggt/utils/pose_enc.py b/vggt/vggt/utils/pose_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d3b964330af0e62f4d36d332317ae00cb99b3a9
--- /dev/null
+++ b/vggt/vggt/utils/pose_enc.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from .rotation import quat_to_mat, mat_to_quat
+
+
+def extri_intri_to_pose_encoding(
+    extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV"  # image_size_hw e.g., (256, 512)
+):
+    """Convert camera extrinsics and intrinsics to a compact pose encoding.
+
+    This function transforms camera parameters into a unified pose encoding format,
+    which can be used for various downstream tasks like pose prediction or representation.
+
+    Args:
+        extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
+            where B is batch size and S is sequence length.
+            In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
+            The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+        intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
+            Defined in pixels, with format:
+            [[fx, 0, cx],
+             [0, fy, cy],
+             [0, 0, 1]]
+            where fx, fy are focal lengths and (cx, cy) is the principal point.
+        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+            Required for computing field of view values. For example: (256, 512).
+        pose_encoding_type (str): Type of pose encoding to use. Currently only
+            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+
+    Returns:
+        torch.Tensor: Encoded camera pose parameters with shape BxSx9.
+            For "absT_quaR_FoV" type, the 9 dimensions are:
+            - [:3] = absolute translation vector T (3D)
+            - [3:7] = rotation as quaternion quat (4D)
+            - [7:] = field of view (2D)
+    """
+
+    # extrinsics: BxSx3x4
+    # intrinsics: BxSx3x3
+
+    if pose_encoding_type == "absT_quaR_FoV":
+        R = extrinsics[:, :, :3, :3]  # BxSx3x3
+        T = extrinsics[:, :, :3, 3]  # BxSx3
+
+        quat = mat_to_quat(R)
+        # Note the order of h and w here
+        H, W = image_size_hw
+        fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
+        fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
+        pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
+    else:
+        raise NotImplementedError
+
+    return pose_encoding
+
+
+def pose_encoding_to_extri_intri(
+    pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True  # image_size_hw e.g., (256, 512)
+):
+    """Convert a pose encoding back to camera extrinsics and intrinsics.
+
+    This function performs the inverse operation of extri_intri_to_pose_encoding,
+    reconstructing the full camera parameters from the compact encoding.
+
+    Args:
+        pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
+            where B is batch size and S is sequence length.
+            For "absT_quaR_FoV" type, the 9 dimensions are:
+            - [:3] = absolute translation vector T (3D)
+            - [3:7] = rotation as quaternion quat (4D)
+            - [7:] = field of view (2D)
+        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+            Required for reconstructing intrinsics from field of view values.
+            For example: (256, 512).
+        pose_encoding_type (str): Type of pose encoding used. Currently only
+            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+        build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
+            If False, only extrinsics are returned and intrinsics will be None.
+
+    Returns:
+        tuple: (extrinsics, intrinsics)
+            - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
+              In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
+              transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
+              a 3x1 translation vector.
+            - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
+              or None if build_intrinsics is False. Defined in pixels, with format:
+              [[fx, 0, cx],
+               [0, fy, cy],
+               [0, 0, 1]]
+              where fx, fy are focal lengths and (cx, cy) is the principal point,
+              assumed to be at the center of the image (W/2, H/2).
+    """
+
+    intrinsics = None
+
+    if pose_encoding_type == "absT_quaR_FoV":
+        T = pose_encoding[..., :3]
+        quat = pose_encoding[..., 3:7]
+        fov_h = pose_encoding[..., 7]
+        fov_w = pose_encoding[..., 8]
+
+        R = quat_to_mat(quat)
+        extrinsics = torch.cat([R, T[..., None]], dim=-1)
+
+        if build_intrinsics:
+            H, W = image_size_hw
+            fy = (H / 2.0) / torch.tan(fov_h / 2.0)
+            fx = (W / 2.0) / torch.tan(fov_w / 2.0)
+            intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
+            intrinsics[..., 0, 0] = fx
+            intrinsics[..., 1, 1] = fy
+            intrinsics[..., 0, 2] = W / 2
+            intrinsics[..., 1, 2] = H / 2
+            intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1
+    else:
+        raise NotImplementedError
+
+    return extrinsics, intrinsics
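A round-trip check of the two conversions above (illustrative only; the vggt.utils.pose_enc import path is inferred from the file layout):

import torch
from vggt.utils.pose_enc import extri_intri_to_pose_encoding, pose_encoding_to_extri_intri

B, S, H, W = 1, 2, 256, 512
extrinsics = torch.eye(4)[:3].repeat(B, S, 1, 1)                 # (B, S, 3, 4) identity cameras
intrinsics = torch.tensor([[300.0, 0.0, W / 2],
                           [0.0, 300.0, H / 2],
                           [0.0, 0.0, 1.0]]).repeat(B, S, 1, 1)  # (B, S, 3, 3)

enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))  # (B, S, 9)
extr, intr = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))
assert torch.allclose(extr, extrinsics, atol=1e-5)
assert torch.allclose(intr, intrinsics, atol=1e-3)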
diff --git a/vggt/vggt/utils/rotation.py b/vggt/vggt/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f972afd8414c82fa1e9ed231725fd3f9f6ebde77
--- /dev/null
+++ b/vggt/vggt/utils/rotation.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
+    """
+    Quaternion Order: XYZW or say ijkr, scalar-last
+
+    Convert rotations given as quaternions to rotation matrices.
+    Args:
+        quaternions: quaternions with real part last,
+            as tensor of shape (..., 4).
+
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
+    """
+    i, j, k, r = torch.unbind(quaternions, -1)
+    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+    two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+    o = torch.stack(
+        (
+            1 - two_s * (j * j + k * k),
+            two_s * (i * j - k * r),
+            two_s * (i * k + j * r),
+            two_s * (i * j + k * r),
+            1 - two_s * (i * i + k * k),
+            two_s * (j * k - i * r),
+            two_s * (i * k - j * r),
+            two_s * (j * k + i * r),
+            1 - two_s * (i * i + j * j),
+        ),
+        -1,
+    )
+    return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+    """
+    Convert rotations given as rotation matrices to quaternions.
+
+    Args:
+        matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+    Returns:
+        quaternions with real part last, as tensor of shape (..., 4).
+        Quaternion Order: XYZW or say ijkr, scalar-last
+    """
+    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+    batch_dim = matrix.shape[:-2]
+    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
+
+    q_abs = _sqrt_positive_part(
+        torch.stack(
+            [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
+        )
+    )
+
+    # we produce the desired quaternion multiplied by each of r, i, j, k
+    quat_by_rijk = torch.stack(
+        [
+            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+            #  `int`.
+            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+            #  `int`.
+            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+            #  `int`.
+            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+            #  `int`.
+            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+        ],
+        dim=-2,
+    )
+
+    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+    # the candidate won't be picked.
+    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+    # forall i; we pick the best-conditioned one (with the largest denominator)
+    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
+
+    # Convert from rijk to ijkr
+    out = out[..., [1, 2, 3, 0]]
+
+    out = standardize_quaternion(out)
+
+    return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+    """
+    Returns torch.sqrt(torch.max(0, x))
+    but with a zero subgradient where x is 0.
+    """
+    ret = torch.zeros_like(x)
+    positive_mask = x > 0
+    if torch.is_grad_enabled():
+        ret[positive_mask] = torch.sqrt(x[positive_mask])
+    else:
+        ret = torch.where(positive_mask, torch.sqrt(x), ret)
+    return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+    """
+    Convert a unit quaternion to a standard form: one in which the real
+    part is non negative.
+
+    Args:
+        quaternions: Quaternions with real part last,
+            as tensor of shape (..., 4).
+
+    Returns:
+        Standardized quaternions as tensor of shape (..., 4).
+    """
+    return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
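A small self-check of the scalar-last (i, j, k, r) convention used by the helpers above (illustrative only; the vggt.utils.rotation import path is inferred from the file layout):

import torch
import torch.nn.functional as F
from vggt.utils.rotation import quat_to_mat, mat_to_quat, standardize_quaternion

q = F.normalize(torch.randn(5, 4), dim=-1)   # random unit quaternions, scalar-last
R = quat_to_mat(q)                           # (5, 3, 3) rotation matrices
q_back = mat_to_quat(R)                      # (5, 4), scalar-last, real part >= 0
# q and -q encode the same rotation, so compare after standardizing the sign
assert torch.allclose(standardize_quaternion(q), q_back, atol=1e-4)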
+ """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/vggt/vggt/utils/visual_track.py b/vggt/vggt/utils/visual_track.py new file mode 100644 index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154 --- /dev/null +++ b/vggt/vggt/utils/visual_track.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import torch +import numpy as np +import os + + +def color_from_xy(x, y, W, H, cmap_name="hsv"): + """ + Map (x, y) -> color in (R, G, B). + 1) Normalize x,y to [0,1]. + 2) Combine them into a single scalar c in [0,1]. + 3) Use matplotlib's colormap to convert c -> (R,G,B). + + You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). + """ + import matplotlib.cm + import matplotlib.colors + + x_norm = x / max(W - 1, 1) + y_norm = y / max(H - 1, 1) + # Simple combination: + c = (x_norm + y_norm) / 2.0 + + cmap = matplotlib.cm.get_cmap(cmap_name) + # cmap(c) -> (r,g,b,a) in [0,1] + rgba = cmap(c) + r, g, b = rgba[0], rgba[1], rgba[2] + return (r, g, b) # in [0,1], RGB order + + +def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): + """ + Given all tracks in one sample (b), compute a (N,3) array of RGB color values + in [0,255]. The color is determined by the (x,y) position in the first + visible frame for each track. + + Args: + tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. + vis_mask_b: (S, N) boolean mask; if None, assume all are visible. + image_width, image_height: used for normalizing (x, y). + cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). + + Returns: + track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. + """ + S, N, _ = tracks_b.shape + track_colors = np.zeros((N, 3), dtype=np.uint8) + + if vis_mask_b is None: + # treat all as visible + vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) + + for i in range(N): + # Find first visible frame for track i + visible_frames = torch.where(vis_mask_b[:, i])[0] + if len(visible_frames) == 0: + # track is never visible; just assign black or something + track_colors[i] = (0, 0, 0) + continue + + first_s = int(visible_frames[0].item()) + # use that frame's (x,y) + x, y = tracks_b[first_s, i].tolist() + + # map (x,y) -> (R,G,B) in [0,1] + r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) + # scale to [0,255] + r, g, b = int(r * 255), int(g * 255), int(b * 255) + track_colors[i] = (r, g, b) + + return track_colors + + +def visualize_tracks_on_images( + images, + tracks, + track_vis_mask=None, + out_dir="track_visuals_concat_by_xy", + image_format="CHW", # "CHW" or "HWC" + normalize_mode="[0,1]", + cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" + frames_per_row=4, # New parameter for grid layout + save_grid=True, # Flag to control whether to save the grid image +): + """ + Visualizes frames in a grid layout with specified frames per row. + Each track's color is determined by its (x,y) position + in the first visible frame (or frame 0 if always visible). + Finally convert the BGR result to RGB before saving. + Also saves each individual frame as a separate PNG file. + + Args: + images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. + tracks: torch.Tensor (S, N, 2), last dim = (x, y). 
+        track_vis_mask: torch.Tensor (S, N) or None.
+        out_dir: folder to save visualizations.
+        image_format: "CHW" or "HWC".
+        normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
+        cmap_name: a matplotlib colormap name for color_from_xy.
+        frames_per_row: number of frames to display in each row of the grid.
+        save_grid: whether to save all frames in one grid image.
+
+    Returns:
+        None (saves images in out_dir).
+    """
+
+    if len(tracks.shape) == 4:
+        tracks = tracks.squeeze(0)
+        images = images.squeeze(0)
+        if track_vis_mask is not None:
+            track_vis_mask = track_vis_mask.squeeze(0)
+
+    import matplotlib
+
+    matplotlib.use("Agg")  # for non-interactive (optional)
+
+    os.makedirs(out_dir, exist_ok=True)
+
+    S = images.shape[0]
+    _, N, _ = tracks.shape  # (S, N, 2)
+
+    # Move to CPU
+    images = images.cpu().clone()
+    tracks = tracks.cpu().clone()
+    if track_vis_mask is not None:
+        track_vis_mask = track_vis_mask.cpu().clone()
+
+    # Infer H, W from images shape
+    if image_format == "CHW":
+        # e.g. images[s].shape = (3, H, W)
+        H, W = images.shape[2], images.shape[3]
+    else:
+        # e.g. images[s].shape = (H, W, 3)
+        H, W = images.shape[1], images.shape[2]
+
+    # Pre-compute the color for each track i based on first visible position
+    track_colors_rgb = get_track_colors_by_position(
+        tracks,  # shape (S, N, 2)
+        vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
+        image_width=W,
+        image_height=H,
+        cmap_name=cmap_name,
+    )
+
+    # We'll accumulate each frame's drawn image in a list
+    frame_images = []
+
+    for s in range(S):
+        # shape => either (3, H, W) or (H, W, 3)
+        img = images[s]
+
+        # Convert to (H, W, 3)
+        if image_format == "CHW":
+            img = img.permute(1, 2, 0)  # (H, W, 3)
+        # else "HWC", do nothing
+
+        img = img.numpy().astype(np.float32)
+
+        # Scale to [0,255] if needed
+        if normalize_mode == "[0,1]":
+            img = np.clip(img, 0, 1) * 255.0
+        elif normalize_mode == "[-1,1]":
+            img = (img + 1.0) * 0.5 * 255.0
+            img = np.clip(img, 0, 255.0)
+        # else no normalization
+
+        # Convert to uint8
+        img = img.astype(np.uint8)
+
+        # For drawing in OpenCV, convert to BGR
+        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+        # Draw each visible track
+        cur_tracks = tracks[s]  # shape (N, 2)
+        if track_vis_mask is not None:
+            valid_indices = torch.where(track_vis_mask[s])[0]
+        else:
+            valid_indices = range(N)
+
+        cur_tracks_np = cur_tracks.numpy()
+        for i in valid_indices:
+            x, y = cur_tracks_np[i]
+            pt = (int(round(x)), int(round(y)))
+
+            # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
+            R, G, B = track_colors_rgb[i]
+            color_bgr = (int(B), int(G), int(R))
+            cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
+
+        # Convert back to RGB for consistent final saving:
+        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+        # Save individual frame
+        frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
+        # Convert to BGR for OpenCV imwrite
+        frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+        cv2.imwrite(frame_path, frame_bgr)
+
+        frame_images.append(img_rgb)
+
+    # Only create and save the grid image if save_grid is True
+    if save_grid:
+        # Calculate grid dimensions
+        num_rows = (S + frames_per_row - 1) // frames_per_row  # Ceiling division
+
+        # Create a grid of images
+        grid_img = None
+        for row in range(num_rows):
+            start_idx = row * frames_per_row
+            end_idx = min(start_idx + frames_per_row, S)
+
+            # Concatenate this row horizontally
+            row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
+
+            # If this row has fewer than frames_per_row images, pad with black
+            if end_idx - start_idx < frames_per_row:
+                padding_width = (frames_per_row - (end_idx - start_idx)) * W
+                padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
+                row_img = np.concatenate([row_img, padding], axis=1)
+
+            # Add this row to the grid
+            if grid_img is None:
+                grid_img = row_img
+            else:
+                grid_img = np.concatenate([grid_img, row_img], axis=0)
+
+        out_path = os.path.join(out_dir, "tracks_grid.png")
+        # Convert back to BGR for OpenCV imwrite
+        grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
+        cv2.imwrite(out_path, grid_img_bgr)
+        print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
+
+    print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
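A hypothetical end-to-end call of the visualizer above, with random tensors standing in for real model outputs (the vggt.utils.visual_track import path is inferred from the file layout):

import torch
from vggt.utils.visual_track import visualize_tracks_on_images

S, N, H, W = 8, 64, 140, 140
images = torch.rand(S, 3, H, W)                                  # CHW frames with values in [0, 1]
tracks = torch.rand(S, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # (x, y) per track per frame
vis_mask = torch.rand(S, N) > 0.1                                # ~90% of points marked visible

visualize_tracks_on_images(
    images, tracks, track_vis_mask=vis_mask,
    out_dir="track_visuals", image_format="CHW",
    normalize_mode="[0,1]", frames_per_row=4, save_grid=True,
)
# Writes track_visuals/frame_0000.png ... frame_0007.png and track_visuals/tracks_grid.png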