JSYuuu committed
Commit 84163e3 · verified · 1 Parent(s): aefca3d

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.json filter=lfs diff=lfs merge=lfs -text
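
This change adds a Git LFS rule for JSON files; the new line is exactly what running `git lfs track "*.json"` in the repo would write to .gitattributes.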
README.md ADDED
@@ -0,0 +1,93 @@
+ ---
+ library_name: diffusers
+ license: apache-2.0
+ pipeline_tag: any-to-any
+ ---
+
+ # ThinkGen: Generalized Thinking for Visual Generation
+
+ ThinkGen is the first think-driven visual generation framework that explicitly leverages Multimodal Large Language Models' (MLLMs) Chain-of-Thought (CoT) reasoning across diverse generation scenarios. ThinkGen employs a decoupled architecture comprising a pretrained MLLM and a Diffusion Transformer (DiT): the MLLM generates tailored instructions based on user intent, and the DiT produces high-quality images guided by these instructions.
+
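+ Conceptually, generation proceeds in two decoupled stages. The sketch below is illustrative pseudocode with hypothetical names, not the actual API (see Sample Usage below for real calls):
+
+ ```python
+ # Illustrative only: the real interface is ThinkGen_Chat (see Sample Usage)
+ instruction = mllm.chain_of_thought(user_prompt)  # reason about intent, rewrite the prompt
+ image = dit.generate(instruction)                 # render an image from the instruction
+ ```
+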
+ - **Paper:** [ThinkGen: Generalized Thinking for Visual Generation](https://huggingface.co/papers/2512.23568)
+ - **Code:** [GitHub Repository](https://github.com/jiaosiyuu/ThinkGen)
+
+ **Authors**: Siyu Jiao, Yiheng Lin, Yujie Zhong, Qi She, Wei Zhou, Xiaohan Lan, Zilong Huang, Fei Yu, Yingchen Yu, Yunqing Zhao, Yao Zhao, Yunchao Wei.
+
+ ## 🚀 Quick Start
+
+ ### 🛠️ Environment Setup
+
+ ```bash
+ # 1. Clone the repo
+ git clone https://github.com/jiaosiyuu/ThinkGen.git
+ cd ThinkGen
+
+ # 2. (Optional) Create a clean Python environment
+ conda create -n thinkgen python=3.11
+ conda activate thinkgen
+
+ # 3. Install dependencies
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
+ pip install -r req.txt
+
+ # ThinkGen runs even without flash-attn, though we recommend installing it for best performance.
+ pip install --no-cache-dir flash-attn==2.7.4.post1 --no-build-isolation
+ ```
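+
+ A quick sanity check of the environment (an illustrative snippet, not part of the official setup):
+
+ ```python
+ import torch
+
+ print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
+ try:
+     import flash_attn  # optional; ThinkGen falls back to default attention without it
+     print("flash-attn:", flash_attn.__version__)
+ except ImportError:
+     print("flash-attn not installed; using default attention")
+ ```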
+
+ ### 💻 Sample Usage
+
+ ```python
+ from ThinkGen.model import ThinkGen_Chat
+
+ model_path = "JSYuuu/ThinkGen-stage3"
+
+ chat_model = ThinkGen_Chat(
+     model_path=model_path,
+     dtype='bf16',
+     height=1024,
+     width=1024
+ )
+
+ # 1. Image Generation
+ messages = [
+     {"type": "text", "value": "A young woman wearing a straw hat, standing in a golden wheat field."}
+ ]
+ results = chat_model.generate_image(messages)
+ results.images[0].save("result.png")
+
+ # 2. Image Generation with Thinking (CoT)
+ # This enables the MLLM's CoT reasoning for generation
+ results_think = chat_model.generate_image(messages, think=True)
+ print(f"cot & rewrite prompt:\n{results_think.prompt_cot}")
+ results_think.images[0].save("result_think.png")
+
+ # 3. Image Understanding
+ messages_und = [
+     {"type": "image", "value": "images/teaser.png"},
+     {"type": "text", "value": "Describe this image"}
+ ]
+ response = chat_model.generate_text(messages_und)
+ print(response)
+ ```
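+
+ By analogy with the understanding example, editing-style generation would presumably combine an image and a text instruction in the same message list. The call below is a hedged sketch, not a documented API; check the GitHub repository for the exact supported schema:
+
+ ```python
+ # Hypothetical editing-style call, assuming generate_image accepts mixed inputs
+ messages_edit = [
+     {"type": "image", "value": "images/teaser.png"},
+     {"type": "text", "value": "Replace the background with a snowy mountain at dusk."}
+ ]
+ results_edit = chat_model.generate_image(messages_edit, think=True)
+ results_edit.images[0].save("result_edit.png")
+ ```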
+
+ ## Acknowledgments
+ This work builds upon the following great open-source projects:
+ * **OmniGen2:** https://github.com/VectorSpaceLab/OmniGen2
+ * **Qwen3VL:** https://github.com/QwenLM/Qwen3-VL
+ * **EasyR1:** https://github.com/hiyouga/EasyR1
+ * **Flow-GRPO:** https://github.com/yifan123/flow_grpo
+
+ ## Citation
+ ```bibtex
+ @article{jiao2025thinkgen,
+   title={ThinkGen: Generalized Thinking for Visual Generation},
+   author={Jiao, Siyu and Lin, Yiheng and Zhong, Yujie and She, Qi and Zhou, Wei and Lan, Xiaohan and Huang, Zilong and Yu, Fei and Yu, Yingchen and Zhao, Yunqing and Zhao, Yao and Wei, Yunchao},
+   journal={arXiv preprint arXiv:2512.23568},
+   year={2025}
+ }
+ ```
+
+ ## License
+ This work is licensed under the Apache 2.0 license.
mllm/config.json ADDED
@@ -0,0 +1,66 @@
+ {
+   "architectures": [
+     "Qwen3VLForConditionalGeneration"
+   ],
+   "dtype": "float32",
+   "eos_token_id": 151645,
+   "image_token_id": 151655,
+   "model_type": "qwen3_vl",
+   "pad_token_id": 151643,
+   "text_config": {
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "bos_token_id": 151643,
+     "dtype": "float32",
+     "eos_token_id": 151645,
+     "head_dim": 128,
+     "hidden_act": "silu",
+     "hidden_size": 4096,
+     "initializer_range": 0.02,
+     "intermediate_size": 12288,
+     "max_position_embeddings": 262144,
+     "model_type": "qwen3_vl_text",
+     "num_attention_heads": 32,
+     "num_hidden_layers": 36,
+     "num_key_value_heads": 8,
+     "rms_norm_eps": 1e-06,
+     "rope_scaling": {
+       "mrope_interleaved": true,
+       "mrope_section": [
+         24,
+         20,
+         20
+       ],
+       "rope_type": "default"
+     },
+     "rope_theta": 5000000,
+     "use_cache": true,
+     "vocab_size": 151936
+   },
+   "tie_word_embeddings": false,
+   "transformers_version": "4.57.1",
+   "video_token_id": 151656,
+   "vision_config": {
+     "deepstack_visual_indexes": [
+       8,
+       16,
+       24
+     ],
+     "depth": 27,
+     "dtype": "float32",
+     "hidden_act": "gelu_pytorch_tanh",
+     "hidden_size": 1152,
+     "in_channels": 3,
+     "initializer_range": 0.02,
+     "intermediate_size": 4304,
+     "model_type": "qwen3_vl",
+     "num_heads": 16,
+     "num_position_embeddings": 2304,
+     "out_hidden_size": 4096,
+     "patch_size": 16,
+     "spatial_merge_size": 2,
+     "temporal_patch_size": 2
+   },
+   "vision_end_token_id": 151653,
+   "vision_start_token_id": 151652
+ }
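
For reference, the configuration above can be inspected programmatically. A minimal sketch, assuming `transformers` >= 4.57 is installed and the Hub is reachable:

```python
from transformers import AutoConfig

# Load the Qwen3-VL config from the mllm/ subfolder of this repo
config = AutoConfig.from_pretrained("JSYuuu/ThinkGen-stage3", subfolder="mllm")
print(config.model_type)                     # qwen3_vl
print(config.text_config.num_hidden_layers)  # 36
print(config.vision_config.depth)            # 27
```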
mllm/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "pad_token_id": 151643,
+   "transformers_version": "4.57.1"
+ }
mllm/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
mllm/model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a2ad68570fe788a0bbf03ed07c1d32cea83884fa328c587a5ef97d797cf2e91
+ size 4902275944
mllm/model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5192d62fbc367626d743551b6a91461bfae305db6cf71eaeade89598d21e4f7d
+ size 4915962496
mllm/model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4931c8b5e85666292daa65153726b20030f54d81d8f51d732a367b9e051e5fbc
+ size 4999831048
mllm/model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef14a367ae345934318ea7b06bf404b97c3a312418f40dd5f92d8296af96de13
+ size 2716270024
mllm/model.safetensors.index.json ADDED
@@ -0,0 +1,757 @@
+ {
+   "metadata": {
+     "total_size": 17534247392
+   },
+   "weight_map": {
+     "lm_head.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.embed_tokens.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.22.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.22.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.22.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.22.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.22.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.22.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.22.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.22.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.34.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.35.input_layernorm.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.layers.35.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.layers.35.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.layers.35.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.layers.35.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.layers.35.self_attn.k_norm.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.layers.35.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.layers.35.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.35.self_attn.q_norm.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.layers.35.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.language_model.layers.35.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
+     "model.language_model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.9.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.9.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.language_model.layers.9.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.9.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.9.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.language_model.norm.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.0.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.1.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.10.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.11.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.12.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.13.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.14.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.15.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.16.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.17.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.18.norm2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.attn.proj.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.attn.proj.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.attn.qkv.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.attn.qkv.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.norm1.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.norm1.weight": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.norm2.bias": "model-00004-of-00004.safetensors",
+     "model.visual.blocks.19.norm2.weight": "model-00004-of-00004.safetensors",
549
+ "model.visual.blocks.2.attn.proj.bias": "model-00004-of-00004.safetensors",
550
+ "model.visual.blocks.2.attn.proj.weight": "model-00004-of-00004.safetensors",
551
+ "model.visual.blocks.2.attn.qkv.bias": "model-00004-of-00004.safetensors",
552
+ "model.visual.blocks.2.attn.qkv.weight": "model-00004-of-00004.safetensors",
553
+ "model.visual.blocks.2.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
554
+ "model.visual.blocks.2.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
555
+ "model.visual.blocks.2.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
556
+ "model.visual.blocks.2.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
557
+ "model.visual.blocks.2.norm1.bias": "model-00004-of-00004.safetensors",
558
+ "model.visual.blocks.2.norm1.weight": "model-00004-of-00004.safetensors",
559
+ "model.visual.blocks.2.norm2.bias": "model-00004-of-00004.safetensors",
560
+ "model.visual.blocks.2.norm2.weight": "model-00004-of-00004.safetensors",
561
+ "model.visual.blocks.20.attn.proj.bias": "model-00004-of-00004.safetensors",
562
+ "model.visual.blocks.20.attn.proj.weight": "model-00004-of-00004.safetensors",
563
+ "model.visual.blocks.20.attn.qkv.bias": "model-00004-of-00004.safetensors",
564
+ "model.visual.blocks.20.attn.qkv.weight": "model-00004-of-00004.safetensors",
565
+ "model.visual.blocks.20.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
566
+ "model.visual.blocks.20.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
567
+ "model.visual.blocks.20.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
568
+ "model.visual.blocks.20.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
569
+ "model.visual.blocks.20.norm1.bias": "model-00004-of-00004.safetensors",
570
+ "model.visual.blocks.20.norm1.weight": "model-00004-of-00004.safetensors",
571
+ "model.visual.blocks.20.norm2.bias": "model-00004-of-00004.safetensors",
572
+ "model.visual.blocks.20.norm2.weight": "model-00004-of-00004.safetensors",
573
+ "model.visual.blocks.21.attn.proj.bias": "model-00004-of-00004.safetensors",
574
+ "model.visual.blocks.21.attn.proj.weight": "model-00004-of-00004.safetensors",
575
+ "model.visual.blocks.21.attn.qkv.bias": "model-00004-of-00004.safetensors",
576
+ "model.visual.blocks.21.attn.qkv.weight": "model-00004-of-00004.safetensors",
577
+ "model.visual.blocks.21.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
578
+ "model.visual.blocks.21.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
579
+ "model.visual.blocks.21.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
580
+ "model.visual.blocks.21.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
581
+ "model.visual.blocks.21.norm1.bias": "model-00004-of-00004.safetensors",
582
+ "model.visual.blocks.21.norm1.weight": "model-00004-of-00004.safetensors",
583
+ "model.visual.blocks.21.norm2.bias": "model-00004-of-00004.safetensors",
584
+ "model.visual.blocks.21.norm2.weight": "model-00004-of-00004.safetensors",
585
+ "model.visual.blocks.22.attn.proj.bias": "model-00004-of-00004.safetensors",
586
+ "model.visual.blocks.22.attn.proj.weight": "model-00004-of-00004.safetensors",
587
+ "model.visual.blocks.22.attn.qkv.bias": "model-00004-of-00004.safetensors",
588
+ "model.visual.blocks.22.attn.qkv.weight": "model-00004-of-00004.safetensors",
589
+ "model.visual.blocks.22.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
590
+ "model.visual.blocks.22.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
591
+ "model.visual.blocks.22.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
592
+ "model.visual.blocks.22.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
593
+ "model.visual.blocks.22.norm1.bias": "model-00004-of-00004.safetensors",
594
+ "model.visual.blocks.22.norm1.weight": "model-00004-of-00004.safetensors",
595
+ "model.visual.blocks.22.norm2.bias": "model-00004-of-00004.safetensors",
596
+ "model.visual.blocks.22.norm2.weight": "model-00004-of-00004.safetensors",
597
+ "model.visual.blocks.23.attn.proj.bias": "model-00004-of-00004.safetensors",
598
+ "model.visual.blocks.23.attn.proj.weight": "model-00004-of-00004.safetensors",
599
+ "model.visual.blocks.23.attn.qkv.bias": "model-00004-of-00004.safetensors",
600
+ "model.visual.blocks.23.attn.qkv.weight": "model-00004-of-00004.safetensors",
601
+ "model.visual.blocks.23.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
602
+ "model.visual.blocks.23.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
603
+ "model.visual.blocks.23.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
604
+ "model.visual.blocks.23.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
605
+ "model.visual.blocks.23.norm1.bias": "model-00004-of-00004.safetensors",
606
+ "model.visual.blocks.23.norm1.weight": "model-00004-of-00004.safetensors",
607
+ "model.visual.blocks.23.norm2.bias": "model-00004-of-00004.safetensors",
608
+ "model.visual.blocks.23.norm2.weight": "model-00004-of-00004.safetensors",
609
+ "model.visual.blocks.24.attn.proj.bias": "model-00004-of-00004.safetensors",
610
+ "model.visual.blocks.24.attn.proj.weight": "model-00004-of-00004.safetensors",
611
+ "model.visual.blocks.24.attn.qkv.bias": "model-00004-of-00004.safetensors",
612
+ "model.visual.blocks.24.attn.qkv.weight": "model-00004-of-00004.safetensors",
613
+ "model.visual.blocks.24.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
614
+ "model.visual.blocks.24.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
615
+ "model.visual.blocks.24.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
616
+ "model.visual.blocks.24.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
617
+ "model.visual.blocks.24.norm1.bias": "model-00004-of-00004.safetensors",
618
+ "model.visual.blocks.24.norm1.weight": "model-00004-of-00004.safetensors",
619
+ "model.visual.blocks.24.norm2.bias": "model-00004-of-00004.safetensors",
620
+ "model.visual.blocks.24.norm2.weight": "model-00004-of-00004.safetensors",
621
+ "model.visual.blocks.25.attn.proj.bias": "model-00004-of-00004.safetensors",
622
+ "model.visual.blocks.25.attn.proj.weight": "model-00004-of-00004.safetensors",
623
+ "model.visual.blocks.25.attn.qkv.bias": "model-00004-of-00004.safetensors",
624
+ "model.visual.blocks.25.attn.qkv.weight": "model-00004-of-00004.safetensors",
625
+ "model.visual.blocks.25.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
626
+ "model.visual.blocks.25.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
627
+ "model.visual.blocks.25.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
628
+ "model.visual.blocks.25.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
629
+ "model.visual.blocks.25.norm1.bias": "model-00004-of-00004.safetensors",
630
+ "model.visual.blocks.25.norm1.weight": "model-00004-of-00004.safetensors",
631
+ "model.visual.blocks.25.norm2.bias": "model-00004-of-00004.safetensors",
632
+ "model.visual.blocks.25.norm2.weight": "model-00004-of-00004.safetensors",
633
+ "model.visual.blocks.26.attn.proj.bias": "model-00004-of-00004.safetensors",
634
+ "model.visual.blocks.26.attn.proj.weight": "model-00004-of-00004.safetensors",
635
+ "model.visual.blocks.26.attn.qkv.bias": "model-00004-of-00004.safetensors",
636
+ "model.visual.blocks.26.attn.qkv.weight": "model-00004-of-00004.safetensors",
637
+ "model.visual.blocks.26.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
638
+ "model.visual.blocks.26.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
639
+ "model.visual.blocks.26.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
640
+ "model.visual.blocks.26.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
641
+ "model.visual.blocks.26.norm1.bias": "model-00004-of-00004.safetensors",
642
+ "model.visual.blocks.26.norm1.weight": "model-00004-of-00004.safetensors",
643
+ "model.visual.blocks.26.norm2.bias": "model-00004-of-00004.safetensors",
644
+ "model.visual.blocks.26.norm2.weight": "model-00004-of-00004.safetensors",
645
+ "model.visual.blocks.3.attn.proj.bias": "model-00004-of-00004.safetensors",
646
+ "model.visual.blocks.3.attn.proj.weight": "model-00004-of-00004.safetensors",
647
+ "model.visual.blocks.3.attn.qkv.bias": "model-00004-of-00004.safetensors",
648
+ "model.visual.blocks.3.attn.qkv.weight": "model-00004-of-00004.safetensors",
649
+ "model.visual.blocks.3.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
650
+ "model.visual.blocks.3.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
651
+ "model.visual.blocks.3.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
652
+ "model.visual.blocks.3.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
653
+ "model.visual.blocks.3.norm1.bias": "model-00004-of-00004.safetensors",
654
+ "model.visual.blocks.3.norm1.weight": "model-00004-of-00004.safetensors",
655
+ "model.visual.blocks.3.norm2.bias": "model-00004-of-00004.safetensors",
656
+ "model.visual.blocks.3.norm2.weight": "model-00004-of-00004.safetensors",
657
+ "model.visual.blocks.4.attn.proj.bias": "model-00004-of-00004.safetensors",
658
+ "model.visual.blocks.4.attn.proj.weight": "model-00004-of-00004.safetensors",
659
+ "model.visual.blocks.4.attn.qkv.bias": "model-00004-of-00004.safetensors",
660
+ "model.visual.blocks.4.attn.qkv.weight": "model-00004-of-00004.safetensors",
661
+ "model.visual.blocks.4.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
662
+ "model.visual.blocks.4.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
663
+ "model.visual.blocks.4.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
664
+ "model.visual.blocks.4.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
665
+ "model.visual.blocks.4.norm1.bias": "model-00004-of-00004.safetensors",
666
+ "model.visual.blocks.4.norm1.weight": "model-00004-of-00004.safetensors",
667
+ "model.visual.blocks.4.norm2.bias": "model-00004-of-00004.safetensors",
668
+ "model.visual.blocks.4.norm2.weight": "model-00004-of-00004.safetensors",
669
+ "model.visual.blocks.5.attn.proj.bias": "model-00004-of-00004.safetensors",
670
+ "model.visual.blocks.5.attn.proj.weight": "model-00004-of-00004.safetensors",
671
+ "model.visual.blocks.5.attn.qkv.bias": "model-00004-of-00004.safetensors",
672
+ "model.visual.blocks.5.attn.qkv.weight": "model-00004-of-00004.safetensors",
673
+ "model.visual.blocks.5.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
674
+ "model.visual.blocks.5.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
675
+ "model.visual.blocks.5.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
676
+ "model.visual.blocks.5.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
677
+ "model.visual.blocks.5.norm1.bias": "model-00004-of-00004.safetensors",
678
+ "model.visual.blocks.5.norm1.weight": "model-00004-of-00004.safetensors",
679
+ "model.visual.blocks.5.norm2.bias": "model-00004-of-00004.safetensors",
680
+ "model.visual.blocks.5.norm2.weight": "model-00004-of-00004.safetensors",
681
+ "model.visual.blocks.6.attn.proj.bias": "model-00004-of-00004.safetensors",
682
+ "model.visual.blocks.6.attn.proj.weight": "model-00004-of-00004.safetensors",
683
+ "model.visual.blocks.6.attn.qkv.bias": "model-00004-of-00004.safetensors",
684
+ "model.visual.blocks.6.attn.qkv.weight": "model-00004-of-00004.safetensors",
685
+ "model.visual.blocks.6.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
686
+ "model.visual.blocks.6.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
687
+ "model.visual.blocks.6.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
688
+ "model.visual.blocks.6.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
689
+ "model.visual.blocks.6.norm1.bias": "model-00004-of-00004.safetensors",
690
+ "model.visual.blocks.6.norm1.weight": "model-00004-of-00004.safetensors",
691
+ "model.visual.blocks.6.norm2.bias": "model-00004-of-00004.safetensors",
692
+ "model.visual.blocks.6.norm2.weight": "model-00004-of-00004.safetensors",
693
+ "model.visual.blocks.7.attn.proj.bias": "model-00004-of-00004.safetensors",
694
+ "model.visual.blocks.7.attn.proj.weight": "model-00004-of-00004.safetensors",
695
+ "model.visual.blocks.7.attn.qkv.bias": "model-00004-of-00004.safetensors",
696
+ "model.visual.blocks.7.attn.qkv.weight": "model-00004-of-00004.safetensors",
697
+ "model.visual.blocks.7.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
698
+ "model.visual.blocks.7.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
699
+ "model.visual.blocks.7.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
700
+ "model.visual.blocks.7.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
701
+ "model.visual.blocks.7.norm1.bias": "model-00004-of-00004.safetensors",
702
+ "model.visual.blocks.7.norm1.weight": "model-00004-of-00004.safetensors",
703
+ "model.visual.blocks.7.norm2.bias": "model-00004-of-00004.safetensors",
704
+ "model.visual.blocks.7.norm2.weight": "model-00004-of-00004.safetensors",
705
+ "model.visual.blocks.8.attn.proj.bias": "model-00004-of-00004.safetensors",
706
+ "model.visual.blocks.8.attn.proj.weight": "model-00004-of-00004.safetensors",
707
+ "model.visual.blocks.8.attn.qkv.bias": "model-00004-of-00004.safetensors",
708
+ "model.visual.blocks.8.attn.qkv.weight": "model-00004-of-00004.safetensors",
709
+ "model.visual.blocks.8.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
710
+ "model.visual.blocks.8.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
711
+ "model.visual.blocks.8.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
712
+ "model.visual.blocks.8.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
713
+ "model.visual.blocks.8.norm1.bias": "model-00004-of-00004.safetensors",
714
+ "model.visual.blocks.8.norm1.weight": "model-00004-of-00004.safetensors",
715
+ "model.visual.blocks.8.norm2.bias": "model-00004-of-00004.safetensors",
716
+ "model.visual.blocks.8.norm2.weight": "model-00004-of-00004.safetensors",
717
+ "model.visual.blocks.9.attn.proj.bias": "model-00004-of-00004.safetensors",
718
+ "model.visual.blocks.9.attn.proj.weight": "model-00004-of-00004.safetensors",
719
+ "model.visual.blocks.9.attn.qkv.bias": "model-00004-of-00004.safetensors",
720
+ "model.visual.blocks.9.attn.qkv.weight": "model-00004-of-00004.safetensors",
721
+ "model.visual.blocks.9.mlp.linear_fc1.bias": "model-00004-of-00004.safetensors",
722
+ "model.visual.blocks.9.mlp.linear_fc1.weight": "model-00004-of-00004.safetensors",
723
+ "model.visual.blocks.9.mlp.linear_fc2.bias": "model-00004-of-00004.safetensors",
724
+ "model.visual.blocks.9.mlp.linear_fc2.weight": "model-00004-of-00004.safetensors",
725
+ "model.visual.blocks.9.norm1.bias": "model-00004-of-00004.safetensors",
726
+ "model.visual.blocks.9.norm1.weight": "model-00004-of-00004.safetensors",
727
+ "model.visual.blocks.9.norm2.bias": "model-00004-of-00004.safetensors",
728
+ "model.visual.blocks.9.norm2.weight": "model-00004-of-00004.safetensors",
729
+ "model.visual.deepstack_merger_list.0.linear_fc1.bias": "model-00004-of-00004.safetensors",
730
+ "model.visual.deepstack_merger_list.0.linear_fc1.weight": "model-00004-of-00004.safetensors",
731
+ "model.visual.deepstack_merger_list.0.linear_fc2.bias": "model-00004-of-00004.safetensors",
732
+ "model.visual.deepstack_merger_list.0.linear_fc2.weight": "model-00004-of-00004.safetensors",
733
+ "model.visual.deepstack_merger_list.0.norm.bias": "model-00004-of-00004.safetensors",
734
+ "model.visual.deepstack_merger_list.0.norm.weight": "model-00004-of-00004.safetensors",
735
+ "model.visual.deepstack_merger_list.1.linear_fc1.bias": "model-00004-of-00004.safetensors",
736
+ "model.visual.deepstack_merger_list.1.linear_fc1.weight": "model-00004-of-00004.safetensors",
737
+ "model.visual.deepstack_merger_list.1.linear_fc2.bias": "model-00004-of-00004.safetensors",
738
+ "model.visual.deepstack_merger_list.1.linear_fc2.weight": "model-00004-of-00004.safetensors",
739
+ "model.visual.deepstack_merger_list.1.norm.bias": "model-00004-of-00004.safetensors",
740
+ "model.visual.deepstack_merger_list.1.norm.weight": "model-00004-of-00004.safetensors",
741
+ "model.visual.deepstack_merger_list.2.linear_fc1.bias": "model-00004-of-00004.safetensors",
742
+ "model.visual.deepstack_merger_list.2.linear_fc1.weight": "model-00004-of-00004.safetensors",
743
+ "model.visual.deepstack_merger_list.2.linear_fc2.bias": "model-00004-of-00004.safetensors",
744
+ "model.visual.deepstack_merger_list.2.linear_fc2.weight": "model-00004-of-00004.safetensors",
745
+ "model.visual.deepstack_merger_list.2.norm.bias": "model-00004-of-00004.safetensors",
746
+ "model.visual.deepstack_merger_list.2.norm.weight": "model-00004-of-00004.safetensors",
747
+ "model.visual.merger.linear_fc1.bias": "model-00004-of-00004.safetensors",
748
+ "model.visual.merger.linear_fc1.weight": "model-00004-of-00004.safetensors",
749
+ "model.visual.merger.linear_fc2.bias": "model-00004-of-00004.safetensors",
750
+ "model.visual.merger.linear_fc2.weight": "model-00004-of-00004.safetensors",
751
+ "model.visual.merger.norm.bias": "model-00004-of-00004.safetensors",
752
+ "model.visual.merger.norm.weight": "model-00004-of-00004.safetensors",
753
+ "model.visual.patch_embed.proj.bias": "model-00004-of-00004.safetensors",
754
+ "model.visual.patch_embed.proj.weight": "model-00004-of-00004.safetensors",
755
+ "model.visual.pos_embed.weight": "model-00004-of-00004.safetensors"
756
+ }
757
+ }
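Every `model.visual.*` tensor above lands in the fourth and last MLLM shard. A minimal sketch for sanity-checking a downloaded checkpoint against an index like this one (the local path is a placeholder, and `metadata` may be absent in some indexes):

```python
import json
from collections import Counter

# Placeholder path to a transformers-style sharded-checkpoint index.
index_path = "model.safetensors.index.json"

with open(index_path) as f:
    index = json.load(f)

# Count how many tensors each shard holds.
per_shard = Counter(index["weight_map"].values())
for shard, n_tensors in sorted(per_shard.items()):
    print(f"{shard}: {n_tensors} tensors")

# The declared byte size of all tensors, if the index carries metadata.
print("total_size:", index.get("metadata", {}).get("total_size"))
```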
model_index.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "_class_name": "ThinkGenPipeline",
+ "_diffusers_version": "0.34.1",
+ "scheduler": [
+ "scheduling_flow_match_euler_discrete",
+ "FlowMatchEulerDiscreteScheduler"
+ ],
+ "transformer": [
+ "transformer_thinkgen",
+ "ThinkGenTransformer2DModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ],
+ "mllm": [
+ "transformers",
+ "Qwen3VLForConditionalGeneration"
+ ],
+ "processor": [
+ "transformers",
+ "Qwen3VLProcessor"
+ ]
+ }
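`model_index.json` declares how the pipeline is assembled: the scheduler and DiT come from local modules in this repo (`scheduling_flow_match_euler_discrete`, `transformer_thinkgen`), while the VAE, MLLM, and processor are stock `diffusers`/`transformers` classes. A sketch that inspects the component map without loading any weights:

```python
import json

with open("model_index.json") as f:
    model_index = json.load(f)

for name, spec in model_index.items():
    if name.startswith("_"):
        continue  # skip _class_name / _diffusers_version metadata
    module, cls = spec
    print(f"{name:12s} -> {module}.{cls}")
```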
processor/added_tokens.json ADDED
@@ -0,0 +1,28 @@
+ {
+ "</think>": 151668,
+ "</tool_call>": 151658,
+ "</tool_response>": 151666,
+ "<think>": 151667,
+ "<tool_call>": 151657,
+ "<tool_response>": 151665,
+ "<|box_end|>": 151649,
+ "<|box_start|>": 151648,
+ "<|endoftext|>": 151643,
+ "<|file_sep|>": 151664,
+ "<|fim_middle|>": 151660,
+ "<|fim_pad|>": 151662,
+ "<|fim_prefix|>": 151659,
+ "<|fim_suffix|>": 151661,
+ "<|im_end|>": 151645,
+ "<|im_start|>": 151644,
+ "<|image_pad|>": 151655,
+ "<|object_ref_end|>": 151647,
+ "<|object_ref_start|>": 151646,
+ "<|quad_end|>": 151651,
+ "<|quad_start|>": 151650,
+ "<|repo_name|>": 151663,
+ "<|video_pad|>": 151656,
+ "<|vision_end|>": 151653,
+ "<|vision_pad|>": 151654,
+ "<|vision_start|>": 151652
+ }
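These IDs should agree with the `added_tokens_decoder` table in `processor/tokenizer_config.json` below; note that `<think>`/`</think>` (151667/151668) are registered but not marked special, so they survive decoding. A quick consistency check (the local path layout is an assumption):

```python
from transformers import AutoTokenizer

# Assumes the repo was downloaded locally; "processor" is the subfolder in this commit.
tokenizer = AutoTokenizer.from_pretrained(".", subfolder="processor")

for token in ("<think>", "</think>", "<|image_pad|>", "<|vision_start|>"):
    print(token, tokenizer.convert_tokens_to_ids(token))
# Expected from added_tokens.json: 151667, 151668, 151655, 151652
```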
processor/chat_template.jinja ADDED
@@ -0,0 +1,110 @@
+ {%- set image_count = namespace(value=0) %}
+ {%- set video_count = namespace(value=0) %}
+ {%- macro render_content(content, do_vision_count) %}
+ {%- if content is string %}
+ {{- content }}
+ {%- else %}
+ {%- for item in content %}
+ {%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
+ {%- if do_vision_count %}
+ {%- set image_count.value = image_count.value + 1 %}
+ {%- endif %}
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
+ <|vision_start|><|image_pad|><|vision_end|>
+ {%- elif 'video' in item or item.type == 'video' %}
+ {%- if do_vision_count %}
+ {%- set video_count.value = video_count.value + 1 %}
+ {%- endif %}
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
+ <|vision_start|><|video_pad|><|vision_end|>
+ {%- elif 'text' in item %}
+ {{- item.text }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {%- endmacro %}
+ {%- if tools %}
+ {{- '<|im_start|>system\n' }}
+ {%- if messages[0].role == 'system' %}
+ {{- render_content(messages[0].content, false) + '\n\n' }}
+ {%- endif %}
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+ {%- for tool in tools %}
+ {{- "\n" }}
+ {{- tool | tojson }}
+ {%- endfor %}
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+ {%- else %}
+ {%- if messages[0].role == 'system' %}
+ {{- '<|im_start|>system\n' + render_content(messages[0].content, false) + '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+ {%- for message in messages[::-1] %}
+ {%- set index = (messages|length - 1) - loop.index0 %}
+ {%- if ns.multi_step_tool and message.role == "user" %}
+ {%- set content = render_content(message.content, false) %}
+ {%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
+ {%- set ns.multi_step_tool = false %}
+ {%- set ns.last_query_index = index %}
+ {%- endif %}
+ {%- endif %}
+ {%- endfor %}
+ {%- for message in messages %}
+ {%- set content = render_content(message.content, True) %}
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+ {%- elif message.role == "assistant" %}
+ {%- set reasoning_content = '' %}
+ {%- if message.reasoning_content is string %}
+ {%- set reasoning_content = message.reasoning_content %}
+ {%- else %}
+ {%- if '</think>' in content %}
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+ {%- endif %}
+ {%- endif %}
+ {%- if loop.index0 > ns.last_query_index %}
+ {%- if loop.last or (not loop.last and reasoning_content) %}
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- endif %}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- endif %}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if (loop.first and content) or (not loop.first) %}
+ {{- '\n' }}
+ {%- endif %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+ {{- '<tool_call>\n{"name": "' }}
+ {{- tool_call.name }}
+ {{- '", "arguments": ' }}
+ {%- if tool_call.arguments is string %}
+ {{- tool_call.arguments }}
+ {%- else %}
+ {{- tool_call.arguments | tojson }}
+ {%- endif %}
+ {{- '}\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+ {{- '<|im_start|>user' }}
+ {%- endif %}
+ {{- '\n<tool_response>\n' }}
+ {{- content }}
+ {{- '\n</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- endfor %}
+ {%- if add_generation_prompt %}
+ {{- '<|im_start|>assistant\n<think>\n' }}
+ {%- endif %}
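Note that with `add_generation_prompt` this template ends the prompt with `<|im_start|>assistant\n<think>\n`, so decoding starts inside a reasoning block by default; `reasoning_content` of earlier assistant turns is re-wrapped in `<think>` tags only for turns after the last user query. A sketch of rendering it (which of the two shipped templates, this `.jinja` file or the legacy string in `chat_template.json` below, gets picked up depends on the transformers version, so treat this as an assumption):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(".", subfolder="processor")

messages = [{"role": "user", "content": "Describe a golden wheat field at sunset."}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)  # should end with "<|im_start|>assistant\n<think>\n"
```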
processor/chat_template.json ADDED
@@ -0,0 +1,3 @@
+ {
+ "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are MiMo, an AI assistant developed by Xiaomi.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+ }
processor/preprocessor_config.json ADDED
@@ -0,0 +1,39 @@
+ {
+ "crop_size": null,
+ "data_format": "channels_first",
+ "default_to_square": true,
+ "device": null,
+ "disable_grouping": null,
+ "do_center_crop": null,
+ "do_convert_rgb": true,
+ "do_normalize": true,
+ "do_pad": null,
+ "do_rescale": true,
+ "do_resize": true,
+ "image_mean": [
+ 0.5,
+ 0.5,
+ 0.5
+ ],
+ "image_processor_type": "Qwen2VLImageProcessorFast",
+ "image_std": [
+ 0.5,
+ 0.5,
+ 0.5
+ ],
+ "input_data_format": null,
+ "max_pixels": null,
+ "merge_size": 2,
+ "min_pixels": null,
+ "pad_size": null,
+ "patch_size": 16,
+ "processor_class": "Qwen3VLProcessor",
+ "resample": 3,
+ "rescale_factor": 0.00392156862745098,
+ "return_tensors": null,
+ "size": {
+ "longest_edge": 16777216,
+ "shortest_edge": 65536
+ },
+ "temporal_patch_size": 2
+ }
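With `patch_size: 16` and `merge_size: 2`, an image yields roughly `(H/16)·(W/16)/4` visual tokens after the 2×2 patch merge; reading `size.shortest_edge`/`size.longest_edge` as pixel-area budgets (256² to 4096²), per the Qwen-VL convention, is an assumption based on the `Qwen2VLImageProcessorFast` defaults. A quick token-count sketch for a 1024×1024 input:

```python
patch_size, merge_size = 16, 2  # from preprocessor_config.json
h = w = 1024

patches = (h // patch_size) * (w // patch_size)  # 64 * 64 = 4096 patches
tokens = patches // (merge_size * merge_size)    # 2x2 merge -> 1024 visual tokens
print(tokens)
```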
processor/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
+ {
+ "additional_special_tokens": [
+ "<|im_start|>",
+ "<|im_end|>",
+ "<|object_ref_start|>",
+ "<|object_ref_end|>",
+ "<|box_start|>",
+ "<|box_end|>",
+ "<|quad_start|>",
+ "<|quad_end|>",
+ "<|vision_start|>",
+ "<|vision_end|>",
+ "<|vision_pad|>",
+ "<|image_pad|>",
+ "<|video_pad|>"
+ ],
+ "eos_token": {
+ "content": "<|im_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
processor/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
+ size 11422654
processor/tokenizer_config.json ADDED
@@ -0,0 +1,241 @@
+ {
+ "add_bos_token": false,
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "151643": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151644": {
+ "content": "<|im_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151645": {
+ "content": "<|im_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151646": {
+ "content": "<|object_ref_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151647": {
+ "content": "<|object_ref_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151648": {
+ "content": "<|box_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151649": {
+ "content": "<|box_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151650": {
+ "content": "<|quad_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151651": {
+ "content": "<|quad_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151652": {
+ "content": "<|vision_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151653": {
+ "content": "<|vision_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151654": {
+ "content": "<|vision_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151655": {
+ "content": "<|image_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151656": {
+ "content": "<|video_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151657": {
+ "content": "<tool_call>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151658": {
+ "content": "</tool_call>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151659": {
+ "content": "<|fim_prefix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151660": {
+ "content": "<|fim_middle|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151661": {
+ "content": "<|fim_suffix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151662": {
+ "content": "<|fim_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151663": {
+ "content": "<|repo_name|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151664": {
+ "content": "<|file_sep|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151665": {
+ "content": "<tool_response>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151666": {
+ "content": "</tool_response>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151667": {
+ "content": "<think>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151668": {
+ "content": "</think>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ }
+ },
+ "additional_special_tokens": [
+ "<|im_start|>",
+ "<|im_end|>",
+ "<|object_ref_start|>",
+ "<|object_ref_end|>",
+ "<|box_start|>",
+ "<|box_end|>",
+ "<|quad_start|>",
+ "<|quad_end|>",
+ "<|vision_start|>",
+ "<|vision_end|>",
+ "<|vision_pad|>",
+ "<|image_pad|>",
+ "<|video_pad|>"
+ ],
+ "bos_token": null,
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "<|im_end|>",
+ "errors": "replace",
+ "extra_special_tokens": {},
+ "model_max_length": 262144,
+ "pad_token": "<|endoftext|>",
+ "processor_class": "Qwen3VLProcessor",
+ "split_special_tokens": false,
+ "tokenizer_class": "Qwen2Tokenizer",
+ "unk_token": null,
+ "use_fast": true
+ }
processor/video_preprocessor_config.json ADDED
@@ -0,0 +1,41 @@
+ {
+ "crop_size": null,
+ "data_format": "channels_first",
+ "default_to_square": true,
+ "device": null,
+ "do_center_crop": null,
+ "do_convert_rgb": true,
+ "do_normalize": true,
+ "do_rescale": true,
+ "do_resize": true,
+ "do_sample_frames": true,
+ "fps": 2,
+ "image_mean": [
+ 0.5,
+ 0.5,
+ 0.5
+ ],
+ "image_std": [
+ 0.5,
+ 0.5,
+ 0.5
+ ],
+ "input_data_format": null,
+ "max_frames": 768,
+ "merge_size": 2,
+ "min_frames": 4,
+ "num_frames": null,
+ "pad_size": null,
+ "patch_size": 16,
+ "processor_class": "Qwen3VLProcessor",
+ "resample": 3,
+ "rescale_factor": 0.00392156862745098,
+ "return_metadata": false,
+ "size": {
+ "longest_edge": 25165824,
+ "shortest_edge": 4096
+ },
+ "temporal_patch_size": 2,
+ "video_metadata": null,
+ "video_processor_type": "Qwen3VLVideoProcessor"
+ }
processor/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
+ "_diffusers_version": "0.34.1",
+ "dynamic_time_shift": true,
+ "num_train_timesteps": 1000
+ }
scheduler/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,228 @@
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import BaseOutput, logging
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
+     """
+     Output class for the scheduler's `step` function output.
+
+     Args:
+         prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+             Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next
+             model input in the denoising loop.
+     """
+
+     prev_sample: torch.FloatTensor
+
+
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+     """
+     Euler scheduler for flow matching.
+
+     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for
+     the generic methods the library implements for all schedulers, such as loading and saving.
+
+     Args:
+         num_train_timesteps (`int`, defaults to 1000):
+             The number of diffusion steps used to train the model.
+         dynamic_time_shift (`bool`, defaults to `False`):
+             Whether to warp the inference timestep grid based on the number of image tokens (see
+             `set_timesteps`).
+     """
+
+     _compatibles = []
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         dynamic_time_shift: bool = False
+     ):
+         timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
+
+         self.timesteps = timesteps
+
+         self._step_index = None
+         self._begin_index = None
+
+     @property
+     def step_index(self):
+         """
+         The index counter for the current timestep. It increases by 1 after each scheduler step.
+         """
+         return self._step_index
+
+     @property
+     def begin_index(self):
+         """
+         The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
+         """
+         return self._begin_index
+
+     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+     def set_begin_index(self, begin_index: int = 0):
+         """
+         Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
+
+         Args:
+             begin_index (`int`):
+                 The begin index for the scheduler.
+         """
+         self._begin_index = begin_index
+
+     def index_for_timestep(self, timestep, schedule_timesteps=None):
+         if schedule_timesteps is None:
+             schedule_timesteps = self._timesteps
+
+         indices = (schedule_timesteps == timestep).nonzero()
+
+         # The sigma index that is taken for the **very** first `step`
+         # is always the second index (or the last index if there is only 1)
+         # This way we can ensure we don't accidentally skip a sigma in
+         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+         pos = 1 if len(indices) > 1 else 0
+
+         return indices[pos].item()
+
+     # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+     #     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+     def set_timesteps(
+         self,
+         num_inference_steps: int = None,
+         device: Union[str, torch.device] = None,
+         timesteps: Optional[List[float]] = None,
+         num_tokens: Optional[int] = None
+     ):
+         """
+         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+         Args:
+             num_inference_steps (`int`):
+                 The number of diffusion steps used when generating samples with a pre-trained model.
+             device (`str` or `torch.device`, *optional*):
+                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+             timesteps (`List[float]`, *optional*):
+                 Custom timesteps in `[0, 1)`. If given, `num_inference_steps` and the dynamic time shift
+                 are ignored.
+             num_tokens (`int`, *optional*):
+                 The number of image tokens, used to compute the dynamic time shift when the scheduler is
+                 configured with `dynamic_time_shift=True`.
+         """
+
+         if timesteps is None:
+             self.num_inference_steps = num_inference_steps
+             timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
+             if self.config.dynamic_time_shift and num_tokens is not None:
+                 m = np.sqrt(num_tokens) / 40  # when input resolution is 320 * 320, m = 1; when input resolution is 1024 * 1024, m = 3.2
+                 timesteps = timesteps / (m - m * timesteps + timesteps)
+
+         # np.asarray also accepts user-provided Python lists for `timesteps`
+         timesteps = torch.from_numpy(np.asarray(timesteps)).to(dtype=torch.float32, device=device)
+         _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
+
+         self.timesteps = timesteps
+         self._timesteps = _timesteps
+         self._step_index = None
+         self._begin_index = None
+
+     def _init_step_index(self, timestep):
+         if self.begin_index is None:
+             if isinstance(timestep, torch.Tensor):
+                 timestep = timestep.to(self.timesteps.device)
+             self._step_index = self.index_for_timestep(timestep)
+         else:
+             self._step_index = self._begin_index
+
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: Union[float, torch.FloatTensor],
+         sample: torch.FloatTensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
+         """
+         Predict the sample from the previous timestep by reversing the SDE. This function propagates the
+         diffusion process from the learned model outputs (here, the predicted velocity).
+
+         Args:
+             model_output (`torch.FloatTensor`):
+                 The direct output from the learned diffusion model.
+             timestep (`float`):
+                 The current discrete timestep in the diffusion chain.
+             sample (`torch.FloatTensor`):
+                 A current instance of a sample created by the diffusion process.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator.
+             return_dict (`bool`):
+                 Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
+
+         Returns:
+             [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise
+                 a tuple is returned where the first element is the sample tensor.
+         """
+
+         if (
+             isinstance(timestep, int)
+             or isinstance(timestep, torch.IntTensor)
+             or isinstance(timestep, torch.LongTensor)
+         ):
+             raise ValueError(
+                 (
+                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                     " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                     " one of the `scheduler.timesteps` as a timestep."
+                 ),
+             )
+
+         if self.step_index is None:
+             self._init_step_index(timestep)
+         # Upcast to avoid precision issues when computing prev_sample
+         sample = sample.to(torch.float32)
+         t = self._timesteps[self.step_index]
+         t_next = self._timesteps[self.step_index + 1]
+
+         prev_sample = sample + (t_next - t) * model_output
+
+         # Cast sample back to model compatible dtype
+         prev_sample = prev_sample.to(model_output.dtype)
+
+         # upon completion increase step index by one
+         self._step_index += 1
+
+         if not return_dict:
+             return (prev_sample,)
+
+         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+     def __len__(self):
+         return self.config.num_train_timesteps
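The only non-standard piece here is `dynamic_time_shift`: `set_timesteps` warps the uniform grid by `t' = t / (m - m·t + t)` with `m = sqrt(num_tokens) / 40`, spending more steps at high noise as resolution grows (per the inline comment, `m = 1` at 320×320 and `m = 3.2` at 1024×1024, which implies `num_tokens` is the latent pixel count). A minimal denoising-loop sketch; the zero velocity is a stand-in for the DiT call:

```python
import torch

# FlowMatchEulerDiscreteScheduler as defined in the file above.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, dynamic_time_shift=True)

latents = torch.randn(1, 16, 128, 128)              # in_channels=16; 128x128 latent
num_tokens = latents.shape[-2] * latents.shape[-1]  # 16384 -> m = sqrt(16384)/40 = 3.2

scheduler.set_timesteps(num_inference_steps=50, num_tokens=num_tokens)
for t in scheduler.timesteps:
    velocity = torch.zeros_like(latents)  # placeholder for the transformer's prediction
    latents = scheduler.step(velocity, t, latents).prev_sample
```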
transformer/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "_class_name": "ThinkGenTransformer2DModel",
+ "_diffusers_version": "0.34.1",
+ "axes_dim_rope": [
+ 40,
+ 40,
+ 40
+ ],
+ "axes_lens": [
+ 10000,
+ 10000,
+ 10000
+ ],
+ "ffn_dim_multiplier": null,
+ "hidden_size": 2520,
+ "in_channels": 16,
+ "multiple_of": 256,
+ "norm_eps": 1e-05,
+ "num_attention_heads": 21,
+ "num_kv_heads": 7,
+ "num_layers": 32,
+ "num_refiner_layers": 2,
+ "out_channels": null,
+ "patch_size": 2,
+ "text_feat_dim": 4096,
+ "timestep_scale": 1000.0
+ }
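These values are internally consistent: `hidden_size / num_attention_heads = 2520 / 21 = 120`, which equals `sum(axes_dim_rope) = 40 + 40 + 40`, so the 3D RoPE axes exactly cover each head, and `num_attention_heads / num_kv_heads = 3` query heads share each KV head (grouped-query attention). A quick check:

```python
hidden_size, num_heads, num_kv_heads = 2520, 21, 7  # from transformer/config.json
axes_dim_rope = [40, 40, 40]

head_dim = hidden_size // num_heads
assert head_dim == sum(axes_dim_rope) == 120  # 3D RoPE spans the full head dim
assert num_heads // num_kv_heads == 3         # GQA: 3 query heads per KV head
```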
transformer/diffusion_pytorch_model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:322d806247df29cd16a1e8d7ddf307844ec13656172bf8e36d80562fcf8fb62f
+ size 9913126464
transformer/diffusion_pytorch_model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6a41d01a025b8166aafe6fa1b605176f6be5136fd1f1e018ce426816b591ac33
+ size 6018290672
transformer/diffusion_pytorch_model.safetensors.index.json ADDED
@@ -0,0 +1,591 @@
+ {
+ "metadata": {
+ "total_size": 15931355544
+ },
+ "weight_map": {
+ "context_refiner.0.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.0.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "context_refiner.1.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "image_index_embedding": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.0.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.1.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.10.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.11.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.12.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
+ "layers.13.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
118
+ "layers.13.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
119
+ "layers.13.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
120
+ "layers.13.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
121
+ "layers.13.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
122
+ "layers.13.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
123
+ "layers.14.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
124
+ "layers.14.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
125
+ "layers.14.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
126
+ "layers.14.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
127
+ "layers.14.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
128
+ "layers.14.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
129
+ "layers.14.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
130
+ "layers.14.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
131
+ "layers.14.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
132
+ "layers.14.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
133
+ "layers.14.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
134
+ "layers.14.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
135
+ "layers.14.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
136
+ "layers.14.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
137
+ "layers.14.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
138
+ "layers.15.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
139
+ "layers.15.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
140
+ "layers.15.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
141
+ "layers.15.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
142
+ "layers.15.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
143
+ "layers.15.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
144
+ "layers.15.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
145
+ "layers.15.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
146
+ "layers.15.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
147
+ "layers.15.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
148
+ "layers.15.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
149
+ "layers.15.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
150
+ "layers.15.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
151
+ "layers.15.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
152
+ "layers.15.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
153
+ "layers.16.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
154
+ "layers.16.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
155
+ "layers.16.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
156
+ "layers.16.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
157
+ "layers.16.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
158
+ "layers.16.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
159
+ "layers.16.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
160
+ "layers.16.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
161
+ "layers.16.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
162
+ "layers.16.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
163
+ "layers.16.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
164
+ "layers.16.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
165
+ "layers.16.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
166
+ "layers.16.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
167
+ "layers.16.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
168
+ "layers.17.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
169
+ "layers.17.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
170
+ "layers.17.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
171
+ "layers.17.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
172
+ "layers.17.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
173
+ "layers.17.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
174
+ "layers.17.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
175
+ "layers.17.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
176
+ "layers.17.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
177
+ "layers.17.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
178
+ "layers.17.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
179
+ "layers.17.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
180
+ "layers.17.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
181
+ "layers.17.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
182
+ "layers.17.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
183
+ "layers.18.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
184
+ "layers.18.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
185
+ "layers.18.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
186
+ "layers.18.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
187
+ "layers.18.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
188
+ "layers.18.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
189
+ "layers.18.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
190
+ "layers.18.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
191
+ "layers.18.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
192
+ "layers.18.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
193
+ "layers.18.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
194
+ "layers.18.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
195
+ "layers.18.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
196
+ "layers.18.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
197
+ "layers.18.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
198
+ "layers.19.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
199
+ "layers.19.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
200
+ "layers.19.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
201
+ "layers.19.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
202
+ "layers.19.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
203
+ "layers.19.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
204
+ "layers.19.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
205
+ "layers.19.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
206
+ "layers.19.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
207
+ "layers.19.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
208
+ "layers.19.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
209
+ "layers.19.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
210
+ "layers.19.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
211
+ "layers.19.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
212
+ "layers.19.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
213
+ "layers.2.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
214
+ "layers.2.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
215
+ "layers.2.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
216
+ "layers.2.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
217
+ "layers.2.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
218
+ "layers.2.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
219
+ "layers.2.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
220
+ "layers.2.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
221
+ "layers.2.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
222
+ "layers.2.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
223
+ "layers.2.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
224
+ "layers.2.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
225
+ "layers.2.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
226
+ "layers.2.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
227
+ "layers.2.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
228
+ "layers.20.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
229
+ "layers.20.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
230
+ "layers.20.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
231
+ "layers.20.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
232
+ "layers.20.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
233
+ "layers.20.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
234
+ "layers.20.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
235
+ "layers.20.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
236
+ "layers.20.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
237
+ "layers.20.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
238
+ "layers.20.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
239
+ "layers.20.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
240
+ "layers.20.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
241
+ "layers.20.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
242
+ "layers.20.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
243
+ "layers.21.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
244
+ "layers.21.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
245
+ "layers.21.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
246
+ "layers.21.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
247
+ "layers.21.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
248
+ "layers.21.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
249
+ "layers.21.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
250
+ "layers.21.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
251
+ "layers.21.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
252
+ "layers.21.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
253
+ "layers.21.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
254
+ "layers.21.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
255
+ "layers.21.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
256
+ "layers.21.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
257
+ "layers.21.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
258
+ "layers.22.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
259
+ "layers.22.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
260
+ "layers.22.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
261
+ "layers.22.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
262
+ "layers.22.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
263
+ "layers.22.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
264
+ "layers.22.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
265
+ "layers.22.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
266
+ "layers.22.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
267
+ "layers.22.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
268
+ "layers.22.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
269
+ "layers.22.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
270
+ "layers.22.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
271
+ "layers.22.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
272
+ "layers.22.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
273
+ "layers.23.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
274
+ "layers.23.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
275
+ "layers.23.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
276
+ "layers.23.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
277
+ "layers.23.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
278
+ "layers.23.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
279
+ "layers.23.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
280
+ "layers.23.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
281
+ "layers.23.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
282
+ "layers.23.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
283
+ "layers.23.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
284
+ "layers.23.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
285
+ "layers.23.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
286
+ "layers.23.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
287
+ "layers.23.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
288
+ "layers.24.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
289
+ "layers.24.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
290
+ "layers.24.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
291
+ "layers.24.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
292
+ "layers.24.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
293
+ "layers.24.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
294
+ "layers.24.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
295
+ "layers.24.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
296
+ "layers.24.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
297
+ "layers.24.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
298
+ "layers.24.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
299
+ "layers.24.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
300
+ "layers.24.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
301
+ "layers.24.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
302
+ "layers.24.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
303
+ "layers.25.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
304
+ "layers.25.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
305
+ "layers.25.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
306
+ "layers.25.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
307
+ "layers.25.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
308
+ "layers.25.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
309
+ "layers.25.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
310
+ "layers.25.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
311
+ "layers.25.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
312
+ "layers.25.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
313
+ "layers.25.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
314
+ "layers.25.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
315
+ "layers.25.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
316
+ "layers.25.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
317
+ "layers.25.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
318
+ "layers.26.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
319
+ "layers.26.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
320
+ "layers.26.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
321
+ "layers.26.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
322
+ "layers.26.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
323
+ "layers.26.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
324
+ "layers.26.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
325
+ "layers.26.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
326
+ "layers.26.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
327
+ "layers.26.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
328
+ "layers.26.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
329
+ "layers.26.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
330
+ "layers.26.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
331
+ "layers.26.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
332
+ "layers.26.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
333
+ "layers.27.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
334
+ "layers.27.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
335
+ "layers.27.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
336
+ "layers.27.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
337
+ "layers.27.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
338
+ "layers.27.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
339
+ "layers.27.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
340
+ "layers.27.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
341
+ "layers.27.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
342
+ "layers.27.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
343
+ "layers.27.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
344
+ "layers.27.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
345
+ "layers.27.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
346
+ "layers.27.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
347
+ "layers.27.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
348
+ "layers.28.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
349
+ "layers.28.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
350
+ "layers.28.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
351
+ "layers.28.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
352
+ "layers.28.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
353
+ "layers.28.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
354
+ "layers.28.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
355
+ "layers.28.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
356
+ "layers.28.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
357
+ "layers.28.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
358
+ "layers.28.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
359
+ "layers.28.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
360
+ "layers.28.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
361
+ "layers.28.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
362
+ "layers.28.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
363
+ "layers.29.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
364
+ "layers.29.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
365
+ "layers.29.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
366
+ "layers.29.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
367
+ "layers.29.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
368
+ "layers.29.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
369
+ "layers.29.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
370
+ "layers.29.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
371
+ "layers.29.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
372
+ "layers.29.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
373
+ "layers.29.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
374
+ "layers.29.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
375
+ "layers.29.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
376
+ "layers.29.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
377
+ "layers.29.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
378
+ "layers.3.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
379
+ "layers.3.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
380
+ "layers.3.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
381
+ "layers.3.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
382
+ "layers.3.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
383
+ "layers.3.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
384
+ "layers.3.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
385
+ "layers.3.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
386
+ "layers.3.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
387
+ "layers.3.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
388
+ "layers.3.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
389
+ "layers.3.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
390
+ "layers.3.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
391
+ "layers.3.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
392
+ "layers.3.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
393
+ "layers.30.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
394
+ "layers.30.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
395
+ "layers.30.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
396
+ "layers.30.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
397
+ "layers.30.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
398
+ "layers.30.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
399
+ "layers.30.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
400
+ "layers.30.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
401
+ "layers.30.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
402
+ "layers.30.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
403
+ "layers.30.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
404
+ "layers.30.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
405
+ "layers.30.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
406
+ "layers.30.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
407
+ "layers.30.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
408
+ "layers.31.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
409
+ "layers.31.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
410
+ "layers.31.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
411
+ "layers.31.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
412
+ "layers.31.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
413
+ "layers.31.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
414
+ "layers.31.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
415
+ "layers.31.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
416
+ "layers.31.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
417
+ "layers.31.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
418
+ "layers.31.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
419
+ "layers.31.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
420
+ "layers.31.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
421
+ "layers.31.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
422
+ "layers.31.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
423
+ "layers.4.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
424
+ "layers.4.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
425
+ "layers.4.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
426
+ "layers.4.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
427
+ "layers.4.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
428
+ "layers.4.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
429
+ "layers.4.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
430
+ "layers.4.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
431
+ "layers.4.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
432
+ "layers.4.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
433
+ "layers.4.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
434
+ "layers.4.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
435
+ "layers.4.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
436
+ "layers.4.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
437
+ "layers.4.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
438
+ "layers.5.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
439
+ "layers.5.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
440
+ "layers.5.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
441
+ "layers.5.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
442
+ "layers.5.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
443
+ "layers.5.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
444
+ "layers.5.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
445
+ "layers.5.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
446
+ "layers.5.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
447
+ "layers.5.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
448
+ "layers.5.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
449
+ "layers.5.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
450
+ "layers.5.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
451
+ "layers.5.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
452
+ "layers.5.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
453
+ "layers.6.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
454
+ "layers.6.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
455
+ "layers.6.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
456
+ "layers.6.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
457
+ "layers.6.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
458
+ "layers.6.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
459
+ "layers.6.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
460
+ "layers.6.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
461
+ "layers.6.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
462
+ "layers.6.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
463
+ "layers.6.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
464
+ "layers.6.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
465
+ "layers.6.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
466
+ "layers.6.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
467
+ "layers.6.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
468
+ "layers.7.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
469
+ "layers.7.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
470
+ "layers.7.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
471
+ "layers.7.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
472
+ "layers.7.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
473
+ "layers.7.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
474
+ "layers.7.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
475
+ "layers.7.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
476
+ "layers.7.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
477
+ "layers.7.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
478
+ "layers.7.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
479
+ "layers.7.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
480
+ "layers.7.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
481
+ "layers.7.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
482
+ "layers.7.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
483
+ "layers.8.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
484
+ "layers.8.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
485
+ "layers.8.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
486
+ "layers.8.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
487
+ "layers.8.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
488
+ "layers.8.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
489
+ "layers.8.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
490
+ "layers.8.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
491
+ "layers.8.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
492
+ "layers.8.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
493
+ "layers.8.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
494
+ "layers.8.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
495
+ "layers.8.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
496
+ "layers.8.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
497
+ "layers.8.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
498
+ "layers.9.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
499
+ "layers.9.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
500
+ "layers.9.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
501
+ "layers.9.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
502
+ "layers.9.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
503
+ "layers.9.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
504
+ "layers.9.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
505
+ "layers.9.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
506
+ "layers.9.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
507
+ "layers.9.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
508
+ "layers.9.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
509
+ "layers.9.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
510
+ "layers.9.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
511
+ "layers.9.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
512
+ "layers.9.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
513
+ "noise_refiner.0.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
514
+ "noise_refiner.0.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
515
+ "noise_refiner.0.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
516
+ "noise_refiner.0.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
517
+ "noise_refiner.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
518
+ "noise_refiner.0.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
519
+ "noise_refiner.0.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
520
+ "noise_refiner.0.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
521
+ "noise_refiner.0.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
522
+ "noise_refiner.0.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
523
+ "noise_refiner.0.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
524
+ "noise_refiner.0.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
525
+ "noise_refiner.0.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
526
+ "noise_refiner.0.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
527
+ "noise_refiner.0.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
528
+ "noise_refiner.1.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
529
+ "noise_refiner.1.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
530
+ "noise_refiner.1.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
531
+ "noise_refiner.1.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
532
+ "noise_refiner.1.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
533
+ "noise_refiner.1.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
534
+ "noise_refiner.1.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
535
+ "noise_refiner.1.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
536
+ "noise_refiner.1.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
537
+ "noise_refiner.1.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
538
+ "noise_refiner.1.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
539
+ "noise_refiner.1.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
540
+ "noise_refiner.1.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
541
+ "noise_refiner.1.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
542
+ "noise_refiner.1.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
543
+ "norm_out.linear_1.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
544
+ "norm_out.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
545
+ "norm_out.linear_2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
546
+ "norm_out.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
547
+ "prepad_embed": "diffusion_pytorch_model-00001-of-00002.safetensors",
548
+ "prepad_mask": "diffusion_pytorch_model-00001-of-00002.safetensors",
549
+ "ref_image_patch_embedder.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
550
+ "ref_image_patch_embedder.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
551
+ "ref_image_refiner.0.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
552
+ "ref_image_refiner.0.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
553
+ "ref_image_refiner.0.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
554
+ "ref_image_refiner.0.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
555
+ "ref_image_refiner.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
556
+ "ref_image_refiner.0.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
557
+ "ref_image_refiner.0.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
558
+ "ref_image_refiner.0.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
559
+ "ref_image_refiner.0.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
560
+ "ref_image_refiner.0.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
561
+ "ref_image_refiner.0.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
562
+ "ref_image_refiner.0.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
563
+ "ref_image_refiner.0.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
564
+ "ref_image_refiner.0.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
565
+ "ref_image_refiner.0.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
566
+ "ref_image_refiner.1.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
567
+ "ref_image_refiner.1.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
568
+ "ref_image_refiner.1.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
569
+ "ref_image_refiner.1.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
570
+ "ref_image_refiner.1.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
571
+ "ref_image_refiner.1.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
572
+ "ref_image_refiner.1.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
573
+ "ref_image_refiner.1.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
574
+ "ref_image_refiner.1.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
575
+ "ref_image_refiner.1.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
576
+ "ref_image_refiner.1.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
577
+ "ref_image_refiner.1.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
578
+ "ref_image_refiner.1.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
579
+ "ref_image_refiner.1.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
580
+ "ref_image_refiner.1.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
581
+ "time_caption_embed.caption_embedder.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
582
+ "time_caption_embed.caption_embedder.1.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
583
+ "time_caption_embed.caption_embedder.1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
584
+ "time_caption_embed.timestep_embedder.linear_1.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
585
+ "time_caption_embed.timestep_embedder.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
586
+ "time_caption_embed.timestep_embedder.linear_2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
587
+ "time_caption_embed.timestep_embedder.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
588
+ "x_embedder.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
589
+ "x_embedder.weight": "diffusion_pytorch_model-00001-of-00002.safetensors"
590
+ }
591
+ }
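The block above is a standard sharded-safetensors index: its `weight_map` assigns every parameter name to the shard file that stores it, so a loader only opens the shards it actually needs. `DiffusionPipeline.from_pretrained` resolves this automatically; the sketch below shows the same lookup done by hand (the local directory path is a placeholder, and the index filename is inferred from the shard names listed above):

```python
import json
from safetensors import safe_open

# Placeholder path to a local checkout of this repo's transformer/ folder.
ckpt_dir = "ThinkGen-stage3/transformer"

with open(f"{ckpt_dir}/diffusion_pytorch_model.safetensors.index.json") as f:
    index = json.load(f)
weight_map = index["weight_map"]  # parameter name -> shard filename, as listed above

# Load a single tensor from whichever shard the index points to.
name = "layers.17.ffn_norm1.weight"  # stored in shard 00002 according to the map
with safe_open(f"{ckpt_dir}/{weight_map[name]}", framework="pt", device="cpu") as shard:
    tensor = shard.get_tensor(name)
print(name, tuple(tensor.shape))
```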
transformer/transformer_thinkgen.py ADDED
@@ -0,0 +1,2457 @@
+import warnings
+import itertools
+from typing import Any, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass
+import math
+import numpy as np
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.attention_processor import Attention
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.embeddings import get_1d_rotary_pos_embed
+from diffusers.models.activations import get_activation
+from diffusers.models.embeddings import Timesteps
+
+
+import importlib.util
+import sys
+
+# The importlib metadata API lives in a different place depending on the Python version.
+if sys.version_info < (3, 8):
+    import importlib_metadata
+else:
+    import importlib.metadata as importlib_metadata
+
+def _is_package_available(pkg_name: str):
+    # Check both that the package can be found and that its metadata resolves.
+    pkg_exists = importlib.util.find_spec(pkg_name) is not None
+    pkg_version = "N/A"
+
+    if pkg_exists:
+        try:
+            pkg_version = importlib_metadata.version(pkg_name)
+        except (ImportError, importlib_metadata.PackageNotFoundError):
+            pkg_exists = False
+
+    return pkg_exists, pkg_version
+
+_triton_available, _triton_version = _is_package_available("triton")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+
+def is_triton_available():
+    return _triton_available
+
+def is_flash_attn_available():
+    return _flash_attn_available
+
+if is_flash_attn_available():
+    from flash_attn import flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+else:
+    warnings.warn("Cannot import flash_attn; install flash_attn to use Flash2Varlen attention for better performance")
+
+
+if is_triton_available():
+    # from ...ops.triton.layer_norm import RMSNorm
+    import triton
+    import triton.language as tl
+
+
+from typing import Callable
+
+
+def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
+    # Wrap torch's custom_fwd/custom_bwd so the same call works on both AMP APIs.
+    def decorator(*args, **kwargs):
+        if cuda_amp_deprecated:
+            kwargs["device_type"] = "cuda"
+        return dec(*args, **kwargs)
+    return decorator
+
+
+if hasattr(torch.amp, "custom_fwd"):  # type: ignore[attr-defined]
+    deprecated = True
+    from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
+else:
+    deprecated = False
+    from torch.cuda.amp import custom_fwd, custom_bwd
+
+custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
+custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
+
+
+def triton_autotune_configs():
+    # Return configs with a valid warp count for the current device
+    configs = []
+    # Maximum threads per block is architecture-dependent in theory, but in practice it is 1024 on all current devices
+    max_threads_per_block = 1024
+    # Default to warp size 32 if not defined by the device
+    warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
+    # Autotune over warp counts that are powers of 2 and do not exceed the threads-per-block limit
+    warp_count = 1
+    while warp_count * warp_size <= max_threads_per_block:
+        configs.append(triton.Config({}, num_warps=warp_count))
+        warp_count *= 2
+    return configs
+
+@triton.autotune(
+    configs=triton_autotune_configs(),
+    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
+)
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
+@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
+@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
+@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
+@triton.jit
+def _layer_norm_fwd_1pass_kernel(
+    X,  # pointer to the input
+    Y,  # pointer to the output
+    W,  # pointer to the weights
+    B,  # pointer to the biases
+    RESIDUAL,  # pointer to the residual
+    X1,
+    W1,
+    B1,
+    Y1,
+    RESIDUAL_OUT,  # pointer to the residual output
+    ROWSCALE,
+    SEEDS,  # dropout seeds for each row
+    DROPOUT_MASK,
+    Mean,  # pointer to the mean
+    Rstd,  # pointer to the 1/std
+    stride_x_row,  # how much to increase the pointer when moving by 1 row
+    stride_y_row,
+    stride_res_row,
+    stride_res_out_row,
+    stride_x1_row,
+    stride_y1_row,
+    M,  # number of rows in X
+    N,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    dropout_p,  # dropout probability
+    zero_centered_weight,  # if true, add 1.0 to the weight
+    IS_RMS_NORM: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    HAS_RESIDUAL: tl.constexpr,
+    STORE_RESIDUAL_OUT: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_DROPOUT: tl.constexpr,
+    STORE_DROPOUT_MASK: tl.constexpr,
+    HAS_ROWSCALE: tl.constexpr,
+    HAS_X1: tl.constexpr,
+    HAS_W1: tl.constexpr,
+    HAS_B1: tl.constexpr,
+):
+    # Map the program id to the row of X and Y it should compute.
+    row = tl.program_id(0)
+    X += row * stride_x_row
+    Y += row * stride_y_row
+    if HAS_RESIDUAL:
+        RESIDUAL += row * stride_res_row
+    if STORE_RESIDUAL_OUT:
+        RESIDUAL_OUT += row * stride_res_out_row
+    if HAS_X1:
+        X1 += row * stride_x1_row
+    if HAS_W1:
+        Y1 += row * stride_y1_row
+    # Compute mean and variance
+    cols = tl.arange(0, BLOCK_N)
+    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    if HAS_ROWSCALE:
+        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
+        x *= rowscale
+    if HAS_DROPOUT:
+        # Compute dropout mask
+        # 7 rounds is good enough, and reduces register pressure
+        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
+        if STORE_DROPOUT_MASK:
+            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
+    if HAS_X1:
+        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
+        if HAS_ROWSCALE:
+            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
+            x1 *= rowscale
+        if HAS_DROPOUT:
+            # Compute dropout mask
+            # 7 rounds is good enough, and reduces register pressure
+            keep_mask = (
+                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+            )
+            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
+            if STORE_DROPOUT_MASK:
+                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
+        x += x1
+    if HAS_RESIDUAL:
+        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
+        x += residual
+    if STORE_RESIDUAL_OUT:
+        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
+    if not IS_RMS_NORM:
+        mean = tl.sum(x, axis=0) / N
+        tl.store(Mean + row, mean)
+        xbar = tl.where(cols < N, x - mean, 0.0)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    else:
+        xbar = tl.where(cols < N, x, 0.0)
207
+ var = tl.sum(xbar * xbar, axis=0) / N
208
+ rstd = 1 / tl.sqrt(var + eps)
209
+ tl.store(Rstd + row, rstd)
210
+ # Normalize and apply linear transformation
211
+ mask = cols < N
212
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
213
+ if zero_centered_weight:
214
+ w += 1.0
215
+ if HAS_BIAS:
216
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
217
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
218
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
219
+ # Write output
220
+ tl.store(Y + cols, y, mask=mask)
221
+ if HAS_W1:
222
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
223
+ if zero_centered_weight:
224
+ w1 += 1.0
225
+ if HAS_B1:
226
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
227
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
228
+ tl.store(Y1 + cols, y1, mask=mask)
229
+
230
+
231
+ def _layer_norm_fwd(
232
+ x,
233
+ weight,
234
+ bias,
235
+ eps,
236
+ residual=None,
237
+ x1=None,
238
+ weight1=None,
239
+ bias1=None,
240
+ dropout_p=0.0,
241
+ rowscale=None,
242
+ out_dtype=None,
243
+ residual_dtype=None,
244
+ zero_centered_weight=False,
245
+ is_rms_norm=False,
246
+ return_dropout_mask=False,
247
+ out=None,
248
+ residual_out=None
249
+ ):
250
+ if residual is not None:
251
+ residual_dtype = residual.dtype
252
+ M, N = x.shape
253
+ assert x.stride(-1) == 1
254
+ if residual is not None:
255
+ assert residual.stride(-1) == 1
256
+ assert residual.shape == (M, N)
257
+ assert weight.shape == (N,)
258
+ assert weight.stride(-1) == 1
259
+ if bias is not None:
260
+ assert bias.stride(-1) == 1
261
+ assert bias.shape == (N,)
262
+ if x1 is not None:
263
+ assert x1.shape == x.shape
264
+ assert rowscale is None
265
+ assert x1.stride(-1) == 1
266
+ if weight1 is not None:
267
+ assert weight1.shape == (N,)
268
+ assert weight1.stride(-1) == 1
269
+ if bias1 is not None:
270
+ assert bias1.shape == (N,)
271
+ assert bias1.stride(-1) == 1
272
+ if rowscale is not None:
273
+ assert rowscale.is_contiguous()
274
+ assert rowscale.shape == (M,)
275
+ # allocate output
276
+ if out is None:
277
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
278
+ else:
279
+ assert out.shape == x.shape
280
+ assert out.stride(-1) == 1
281
+ if weight1 is not None:
282
+ y1 = torch.empty_like(out)
283
+ assert y1.stride(-1) == 1
284
+ else:
285
+ y1 = None
286
+ if (
287
+ residual is not None
288
+ or (residual_dtype is not None and residual_dtype != x.dtype)
289
+ or dropout_p > 0.0
290
+ or rowscale is not None
291
+ or x1 is not None
292
+ ):
293
+ if residual_out is None:
294
+ residual_out = torch.empty(
295
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
296
+ )
297
+ else:
298
+ assert residual_out.shape == x.shape
299
+ assert residual_out.stride(-1) == 1
300
+ else:
301
+ residual_out = None
302
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
303
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
304
+ if dropout_p > 0.0:
305
+ seeds = torch.randint(
306
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
307
+ )
308
+ else:
309
+ seeds = None
310
+ if return_dropout_mask and dropout_p > 0.0:
311
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
312
+ else:
313
+ dropout_mask = None
314
+ # Less than 64KB per feature: enqueue fused kernel
315
+ MAX_FUSED_SIZE = 65536 // x.element_size()
316
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
317
+ if N > BLOCK_N:
318
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
319
+ with torch.cuda.device(x.device.index):
320
+ _layer_norm_fwd_1pass_kernel[(M,)](
321
+ x,
322
+ out,
323
+ weight,
324
+ bias,
325
+ residual,
326
+ x1,
327
+ weight1,
328
+ bias1,
329
+ y1,
330
+ residual_out,
331
+ rowscale,
332
+ seeds,
333
+ dropout_mask,
334
+ mean,
335
+ rstd,
336
+ x.stride(0),
337
+ out.stride(0),
338
+ residual.stride(0) if residual is not None else 0,
339
+ residual_out.stride(0) if residual_out is not None else 0,
340
+ x1.stride(0) if x1 is not None else 0,
341
+ y1.stride(0) if y1 is not None else 0,
342
+ M,
343
+ N,
344
+ eps,
345
+ dropout_p,
346
+ zero_centered_weight,
347
+ is_rms_norm,
348
+ BLOCK_N,
349
+ residual is not None,
350
+ residual_out is not None,
351
+ bias is not None,
352
+ dropout_p > 0.0,
353
+ dropout_mask is not None,
354
+ rowscale is not None,
355
+ )
356
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
357
+ if dropout_mask is not None and x1 is not None:
358
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
359
+ else:
360
+ dropout_mask1 = None
361
+ return (
362
+ out,
363
+ y1,
364
+ mean,
365
+ rstd,
366
+ residual_out if residual_out is not None else x,
367
+ seeds,
368
+ dropout_mask,
369
+ dropout_mask1,
370
+ )
371
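+ # Minimal usage sketch (assumed shapes/devices, not part of the public API):
+ # x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
+ # w = torch.ones(1024, device="cuda", dtype=torch.float16)
+ # y, y1, mean, rstd, res, seeds, m0, m1 = _layer_norm_fwd(x, w, None, 1e-6, is_rms_norm=True)
+ # (mean is None on the RMSNorm path; res aliases x when no residual/dropout/rowscale is used)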
+
372
+ @triton.autotune(
373
+ configs=triton_autotune_configs(),
374
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
375
+ )
376
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
377
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
378
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
379
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
380
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
381
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
382
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
383
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
384
+ @triton.jit
385
+ def _layer_norm_bwd_kernel(
386
+ X, # pointer to the input
387
+ W, # pointer to the weights
388
+ B, # pointer to the biases
389
+ Y, # pointer to the output to be recomputed
390
+ DY, # pointer to the output gradient
391
+ DX, # pointer to the input gradient
392
+ DW, # pointer to the partial sum of weights gradient
393
+ DB, # pointer to the partial sum of biases gradient
394
+ DRESIDUAL,
395
+ W1,
396
+ DY1,
397
+ DX1,
398
+ DW1,
399
+ DB1,
400
+ DRESIDUAL_IN,
401
+ ROWSCALE,
402
+ SEEDS,
403
+ Mean, # pointer to the mean
404
+ Rstd, # pointer to the 1/std
405
+ stride_x_row, # how much to increase the pointer when moving by 1 row
406
+ stride_y_row,
407
+ stride_dy_row,
408
+ stride_dx_row,
409
+ stride_dres_row,
410
+ stride_dy1_row,
411
+ stride_dx1_row,
412
+ stride_dres_in_row,
413
+ M, # number of rows in X
414
+ N, # number of columns in X
415
+ eps, # epsilon to avoid division by zero
416
+ dropout_p,
417
+ zero_centered_weight,
418
+ rows_per_program,
419
+ IS_RMS_NORM: tl.constexpr,
420
+ BLOCK_N: tl.constexpr,
421
+ HAS_DRESIDUAL: tl.constexpr,
422
+ STORE_DRESIDUAL: tl.constexpr,
423
+ HAS_BIAS: tl.constexpr,
424
+ HAS_DROPOUT: tl.constexpr,
425
+ HAS_ROWSCALE: tl.constexpr,
426
+ HAS_DY1: tl.constexpr,
427
+ HAS_DX1: tl.constexpr,
428
+ HAS_B1: tl.constexpr,
429
+ RECOMPUTE_OUTPUT: tl.constexpr,
430
+ ):
431
+ # Map the program id to the elements of X, DX, and DY it should compute.
432
+ row_block_id = tl.program_id(0)
433
+ row_start = row_block_id * rows_per_program
434
+ # Do not early exit if row_start >= M, because we need to write DW and DB
435
+ cols = tl.arange(0, BLOCK_N)
436
+ mask = cols < N
437
+ X += row_start * stride_x_row
438
+ if HAS_DRESIDUAL:
439
+ DRESIDUAL += row_start * stride_dres_row
440
+ if STORE_DRESIDUAL:
441
+ DRESIDUAL_IN += row_start * stride_dres_in_row
442
+ DY += row_start * stride_dy_row
443
+ DX += row_start * stride_dx_row
444
+ if HAS_DY1:
445
+ DY1 += row_start * stride_dy1_row
446
+ if HAS_DX1:
447
+ DX1 += row_start * stride_dx1_row
448
+ if RECOMPUTE_OUTPUT:
449
+ Y += row_start * stride_y_row
450
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
451
+ if zero_centered_weight:
452
+ w += 1.0
453
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
454
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
455
+ if HAS_DY1:
456
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
457
+ if zero_centered_weight:
458
+ w1 += 1.0
459
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
460
+ if HAS_BIAS:
461
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
462
+ if HAS_DY1:
463
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
464
+ if HAS_B1:
465
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
466
+ row_end = min((row_block_id + 1) * rows_per_program, M)
467
+ for row in range(row_start, row_end):
468
+ # Load data to SRAM
469
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
470
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
471
+ if HAS_DY1:
472
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
473
+ if not IS_RMS_NORM:
474
+ mean = tl.load(Mean + row)
475
+ rstd = tl.load(Rstd + row)
476
+ # Compute dx
477
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
478
+ xhat = tl.where(mask, xhat, 0.0)
479
+ if RECOMPUTE_OUTPUT:
480
+ y = xhat * w + b if HAS_BIAS else xhat * w
481
+ tl.store(Y + cols, y, mask=mask)
482
+ wdy = w * dy
483
+ dw += dy * xhat
484
+ if HAS_BIAS:
485
+ db += dy
486
+ if HAS_DY1:
487
+ wdy += w1 * dy1
488
+ dw1 += dy1 * xhat
489
+ if HAS_B1:
490
+ db1 += dy1
491
+ if not IS_RMS_NORM:
492
+ c1 = tl.sum(xhat * wdy, axis=0) / N
493
+ c2 = tl.sum(wdy, axis=0) / N
494
+ dx = (wdy - (xhat * c1 + c2)) * rstd
495
+ else:
496
+ c1 = tl.sum(xhat * wdy, axis=0) / N
497
+ dx = (wdy - xhat * c1) * rstd
498
+ if HAS_DRESIDUAL:
499
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
500
+ dx += dres
501
+ # Write dx
502
+ if STORE_DRESIDUAL:
503
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
504
+ if HAS_DX1:
505
+ if HAS_DROPOUT:
506
+ keep_mask = (
507
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
508
+ )
509
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
510
+ else:
511
+ dx1 = dx
512
+ tl.store(DX1 + cols, dx1, mask=mask)
513
+ if HAS_DROPOUT:
514
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
515
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
516
+ if HAS_ROWSCALE:
517
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
518
+ dx *= rowscale
519
+ tl.store(DX + cols, dx, mask=mask)
520
+
521
+ X += stride_x_row
522
+ if HAS_DRESIDUAL:
523
+ DRESIDUAL += stride_dres_row
524
+ if STORE_DRESIDUAL:
525
+ DRESIDUAL_IN += stride_dres_in_row
526
+ if RECOMPUTE_OUTPUT:
527
+ Y += stride_y_row
528
+ DY += stride_dy_row
529
+ DX += stride_dx_row
530
+ if HAS_DY1:
531
+ DY1 += stride_dy1_row
532
+ if HAS_DX1:
533
+ DX1 += stride_dx1_row
534
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
535
+ if HAS_BIAS:
536
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
537
+ if HAS_DY1:
538
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
539
+ if HAS_B1:
540
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
541
+
542
+
543
+ def _layer_norm_bwd(
544
+ dy,
545
+ x,
546
+ weight,
547
+ bias,
548
+ eps,
549
+ mean,
550
+ rstd,
551
+ dresidual=None,
552
+ dy1=None,
553
+ weight1=None,
554
+ bias1=None,
555
+ seeds=None,
556
+ dropout_p=0.0,
557
+ rowscale=None,
558
+ has_residual=False,
559
+ has_x1=False,
560
+ zero_centered_weight=False,
561
+ is_rms_norm=False,
562
+ x_dtype=None,
563
+ recompute_output=False,
564
+ ):
565
+ M, N = x.shape
566
+ assert x.stride(-1) == 1
567
+ assert dy.stride(-1) == 1
568
+ assert dy.shape == (M, N)
569
+ if dresidual is not None:
570
+ assert dresidual.stride(-1) == 1
571
+ assert dresidual.shape == (M, N)
572
+ assert weight.shape == (N,)
573
+ assert weight.stride(-1) == 1
574
+ if bias is not None:
575
+ assert bias.stride(-1) == 1
576
+ assert bias.shape == (N,)
577
+ if dy1 is not None:
578
+ assert weight1 is not None
579
+ assert dy1.shape == dy.shape
580
+ assert dy1.stride(-1) == 1
581
+ if weight1 is not None:
582
+ assert weight1.shape == (N,)
583
+ assert weight1.stride(-1) == 1
584
+ if bias1 is not None:
585
+ assert bias1.shape == (N,)
586
+ assert bias1.stride(-1) == 1
587
+ if seeds is not None:
588
+ assert seeds.is_contiguous()
589
+ assert seeds.shape == (M if not has_x1 else M * 2,)
590
+ if rowscale is not None:
591
+ assert rowscale.is_contiguous()
592
+ assert rowscale.shape == (M,)
593
+ # allocate output
594
+ dx = (
595
+ torch.empty_like(x)
596
+ if x_dtype is None
597
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
598
+ )
599
+ dresidual_in = (
600
+ torch.empty_like(x)
601
+ if has_residual
602
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
603
+ else None
604
+ )
605
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
606
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
607
+ if recompute_output:
608
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
609
+
610
+ # Less than 64KB per feature: enqueue fused kernel
611
+ MAX_FUSED_SIZE = 65536 // x.element_size()
612
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
613
+ if N > BLOCK_N:
614
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
615
+ # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
616
+ # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
617
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
618
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
619
+ _db = (
620
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
621
+ if bias is not None
622
+ else None
623
+ )
624
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
625
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
626
+ rows_per_program = math.ceil(M / sm_count)
627
+ grid = (sm_count,)
628
+ with torch.cuda.device(x.device.index):
629
+ _layer_norm_bwd_kernel[grid](
630
+ x,
631
+ weight,
632
+ bias,
633
+ y,
634
+ dy,
635
+ dx,
636
+ _dw,
637
+ _db,
638
+ dresidual,
639
+ weight1,
640
+ dy1,
641
+ dx1,
642
+ _dw1,
643
+ _db1,
644
+ dresidual_in,
645
+ rowscale,
646
+ seeds,
647
+ mean,
648
+ rstd,
649
+ x.stride(0),
650
+ 0 if not recompute_output else y.stride(0),
651
+ dy.stride(0),
652
+ dx.stride(0),
653
+ dresidual.stride(0) if dresidual is not None else 0,
654
+ dy1.stride(0) if dy1 is not None else 0,
655
+ dx1.stride(0) if dx1 is not None else 0,
656
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
657
+ M,
658
+ N,
659
+ eps,
660
+ dropout_p,
661
+ zero_centered_weight,
662
+ rows_per_program,
663
+ is_rms_norm,
664
+ BLOCK_N,
665
+ dresidual is not None,
666
+ dresidual_in is not None,
667
+ bias is not None,
668
+ dropout_p > 0.0,
669
+ )
670
+ dw = _dw.sum(0).to(weight.dtype)
671
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
672
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
673
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
674
+ # Don't need to compute dresidual_in separately in this case
675
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
676
+ dresidual_in = dx
677
+ if has_x1 and dropout_p == 0.0:
678
+ dx1 = dx
679
+ return (
680
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
681
+ if not recompute_output
682
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
683
+ )
684
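+ # Note: the per-block partial gradients _dw/_db are float32 buffers of shape (sm_count, N) that
+ # are reduced with a single .sum(0); e.g. on an 80-SM GPU the 8x multiplier gives a (640, N) buffer.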
+
685
+ class LayerNormFn(torch.autograd.Function):
686
+ @staticmethod
687
+ def forward(
688
+ ctx,
689
+ x,
690
+ weight,
691
+ bias,
692
+ residual=None,
693
+ x1=None,
694
+ weight1=None,
695
+ bias1=None,
696
+ eps=1e-6,
697
+ dropout_p=0.0,
698
+ rowscale=None,
699
+ prenorm=False,
700
+ residual_in_fp32=False,
701
+ zero_centered_weight=False,
702
+ is_rms_norm=False,
703
+ return_dropout_mask=False,
704
+ out=None,
705
+ residual_out=None
706
+ ):
707
+ x_shape_og = x.shape
708
+ # Check for zero sequence length
709
+ if x.numel() == 0:
710
+ ctx.zero_seq_length = True
711
+ # Only save minimal required tensors for backward
712
+ # ctx.save_for_backward(weight, bias, weight1, bias1)
713
+ ctx.x_shape_og = x_shape_og
714
+ ctx.weight_shape = weight.shape
715
+ ctx.weight_dtype = weight.dtype
716
+ ctx.weight_device = weight.device
717
+
718
+ ctx.has_bias = bias is not None
719
+ ctx.bias_shape = bias.shape if bias is not None else None
720
+ ctx.bias_dtype = bias.dtype if bias is not None else None
721
+ ctx.bias_device = bias.device if bias is not None else None
722
+
723
+ ctx.has_weight1 = weight1 is not None
724
+ ctx.weight1_shape = weight1.shape if weight1 is not None else None
725
+ ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
726
+ ctx.weight1_device = weight1.device if weight1 is not None else None
727
+
728
+ ctx.has_bias1 = bias1 is not None
729
+ ctx.bias1_shape = bias1.shape if bias1 is not None else None
730
+ ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
731
+ ctx.bias1_device = bias1.device if bias1 is not None else None
732
+
733
+ ctx.has_residual = residual is not None
734
+ ctx.has_x1 = x1 is not None
735
+ ctx.dropout_p = dropout_p
736
+
737
+ # Handle output tensors with correct dtype
738
+ y = x # Preserve input tensor properties
739
+ y1 = torch.empty_like(x) if x1 is not None else None
740
+
741
+ # Only create residual_out if prenorm is True
742
+ residual_out = torch.empty(x.shape,
743
+ dtype=torch.float32 if residual_in_fp32 else x.dtype,
744
+ device=x.device) if prenorm else None
745
+
746
+ # Handle dropout masks
747
+ dropout_mask = None
748
+ dropout_mask1 = None
749
+ if return_dropout_mask:
750
+ dropout_mask = torch.empty_like(x, dtype=torch.uint8)
751
+ if x1 is not None:
752
+ dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)
753
+
754
+ # Return based on configuration
755
+ if not return_dropout_mask:
756
+ if weight1 is None:
757
+ return y if not prenorm else (y, residual_out)
758
+ else:
759
+ return (y, y1) if not prenorm else (y, y1, residual_out)
760
+ else:
761
+ if weight1 is None:
762
+ return ((y, dropout_mask, dropout_mask1) if not prenorm
763
+ else (y, residual_out, dropout_mask, dropout_mask1))
764
+ else:
765
+ return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
766
+ else (y, y1, residual_out, dropout_mask, dropout_mask1))
767
+
768
+ ctx.zero_seq_length = False
769
+ # reshape input data into 2D tensor
770
+ x = x.reshape(-1, x.shape[-1])
771
+ if x.stride(-1) != 1:
772
+ x = x.contiguous()
773
+ if residual is not None:
774
+ assert residual.shape == x_shape_og
775
+ residual = residual.reshape(-1, residual.shape[-1])
776
+ if residual.stride(-1) != 1:
777
+ residual = residual.contiguous()
778
+ if x1 is not None:
779
+ assert x1.shape == x_shape_og
780
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
781
+ x1 = x1.reshape(-1, x1.shape[-1])
782
+ if x1.stride(-1) != 1:
783
+ x1 = x1.contiguous()
784
+ weight = weight.contiguous()
785
+ if bias is not None:
786
+ bias = bias.contiguous()
787
+ if weight1 is not None:
788
+ weight1 = weight1.contiguous()
789
+ if bias1 is not None:
790
+ bias1 = bias1.contiguous()
791
+ if rowscale is not None:
792
+ rowscale = rowscale.reshape(-1).contiguous()
793
+ residual_dtype = (
794
+ residual.dtype
795
+ if residual is not None
796
+ else (torch.float32 if residual_in_fp32 else None)
797
+ )
798
+ if out is not None:
799
+ out = out.reshape(-1, out.shape[-1])
800
+ if residual_out is not None:
801
+ residual_out = residual_out.reshape(-1, residual_out.shape[-1])
802
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
803
+ x,
804
+ weight,
805
+ bias,
806
+ eps,
807
+ residual,
808
+ x1,
809
+ weight1,
810
+ bias1,
811
+ dropout_p=dropout_p,
812
+ rowscale=rowscale,
813
+ residual_dtype=residual_dtype,
814
+ zero_centered_weight=zero_centered_weight,
815
+ is_rms_norm=is_rms_norm,
816
+ return_dropout_mask=return_dropout_mask,
817
+ out=out,
818
+ residual_out=residual_out
819
+ )
820
+ ctx.save_for_backward(
821
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
822
+ )
823
+ ctx.x_shape_og = x_shape_og
824
+ ctx.eps = eps
825
+ ctx.dropout_p = dropout_p
826
+ ctx.is_rms_norm = is_rms_norm
827
+ ctx.has_residual = residual is not None
828
+ ctx.has_x1 = x1 is not None
829
+ ctx.prenorm = prenorm
830
+ ctx.x_dtype = x.dtype
831
+ ctx.zero_centered_weight = zero_centered_weight
832
+ y = y.reshape(x_shape_og)
833
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
834
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
835
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
836
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
837
+ if not return_dropout_mask:
838
+ if weight1 is None:
839
+ return y if not prenorm else (y, residual_out)
840
+ else:
841
+ return (y, y1) if not prenorm else (y, y1, residual_out)
842
+ else:
843
+ if weight1 is None:
844
+ return (
845
+ (y, dropout_mask, dropout_mask1)
846
+ if not prenorm
847
+ else (y, residual_out, dropout_mask, dropout_mask1)
848
+ )
849
+ else:
850
+ return (
851
+ (y, y1, dropout_mask, dropout_mask1)
852
+ if not prenorm
853
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
854
+ )
855
+
856
+ @staticmethod
857
+ def backward(ctx, dy, *args):
858
+ if ctx.zero_seq_length:
859
+ return (
860
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
861
+ torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
862
+ torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
863
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
864
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
865
+ torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
866
+ torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
867
+ None,
868
+ None,
869
+ None,
870
+ None,
871
+ None,
872
+ None,
873
+ None,
874
+ None,
875
+ None,
876
+ None,
877
+ )
878
+
879
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
880
+ dy = dy.reshape(-1, dy.shape[-1])
881
+ if dy.stride(-1) != 1:
882
+ dy = dy.contiguous()
883
+ assert dy.shape == x.shape
884
+ if weight1 is not None:
885
+ dy1, args = args[0], args[1:]
886
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
887
+ if dy1.stride(-1) != 1:
888
+ dy1 = dy1.contiguous()
889
+ assert dy1.shape == x.shape
890
+ else:
891
+ dy1 = None
892
+ if ctx.prenorm:
893
+ dresidual = args[0]
894
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
895
+ if dresidual.stride(-1) != 1:
896
+ dresidual = dresidual.contiguous()
897
+ assert dresidual.shape == x.shape
898
+ else:
899
+ dresidual = None
900
+
901
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
902
+ dy,
903
+ x,
904
+ weight,
905
+ bias,
906
+ ctx.eps,
907
+ mean,
908
+ rstd,
909
+ dresidual,
910
+ dy1,
911
+ weight1,
912
+ bias1,
913
+ seeds,
914
+ ctx.dropout_p,
915
+ rowscale,
916
+ ctx.has_residual,
917
+ ctx.has_x1,
918
+ ctx.zero_centered_weight,
919
+ ctx.is_rms_norm,
920
+ x_dtype=ctx.x_dtype,
921
+ )
922
+ return (
923
+ dx.reshape(ctx.x_shape_og),
924
+ dw,
925
+ db,
926
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
927
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
928
+ dw1,
929
+ db1,
930
+ None,
931
+ None,
932
+ None,
933
+ None,
934
+ None,
935
+ None,
936
+ None,
937
+ None,
938
+ None,
939
+ None,
940
+ )
941
+
942
+ def rms_norm_fn(
943
+ x,
944
+ weight,
945
+ bias,
946
+ residual=None,
947
+ x1=None,
948
+ weight1=None,
949
+ bias1=None,
950
+ eps=1e-6,
951
+ dropout_p=0.0,
952
+ rowscale=None,
953
+ prenorm=False,
954
+ residual_in_fp32=False,
955
+ zero_centered_weight=False,
956
+ return_dropout_mask=False,
957
+ out=None,
958
+ residual_out=None
959
+ ):
960
+ return LayerNormFn.apply(
961
+ x,
962
+ weight,
963
+ bias,
964
+ residual,
965
+ x1,
966
+ weight1,
967
+ bias1,
968
+ eps,
969
+ dropout_p,
970
+ rowscale,
971
+ prenorm,
972
+ residual_in_fp32,
973
+ zero_centered_weight,
974
+ True,  # is_rms_norm
975
+ return_dropout_mask,
976
+ out,
977
+ residual_out
978
+ )
979
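+ # Usage sketch (assumed tensors): fused RMSNorm with a prenorm residual stream, e.g.
+ # y, res = rms_norm_fn(x, weight, None, residual=res, prenorm=True, residual_in_fp32=True)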
+
980
+ class RMSNorm(torch.nn.Module):
981
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
982
+ device=None, dtype=None):
983
+ factory_kwargs = {"device": device, "dtype": dtype}
984
+ super().__init__()
985
+ self.eps = eps
986
+ if dropout_p > 0.0:
987
+ self.drop = torch.nn.Dropout(dropout_p)
988
+ else:
989
+ self.drop = None
990
+ self.zero_centered_weight = zero_centered_weight
991
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
992
+ self.register_parameter("bias", None)
993
+ self.reset_parameters()
994
+
995
+ def reset_parameters(self):
996
+ if not self.zero_centered_weight:
997
+ torch.nn.init.ones_(self.weight)
998
+ else:
999
+ torch.nn.init.zeros_(self.weight)
1000
+
1001
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1002
+ return rms_norm_fn(
1003
+ x,
1004
+ self.weight,
1005
+ self.bias,
1006
+ residual=residual,
1007
+ eps=self.eps,
1008
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1009
+ prenorm=prenorm,
1010
+ residual_in_fp32=residual_in_fp32,
1011
+ zero_centered_weight=self.zero_centered_weight,
1012
+ )
1013
+ else:
1014
+ from torch.nn import RMSNorm
1015
+ warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")
1016
+
1017
+ def swiglu(x, y):
1018
+ return F.silu(x.float(), inplace=False).to(x.dtype) * y
1019
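+ # swiglu(a, b) computes SiLU(a) * b, with the SiLU evaluated in float32 for numerical stability
+ # and cast back to the input dtype; this is the gating used by LuminaFeedForward below.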
+
1020
+ logger = logging.get_logger(__name__)
1021
+
1022
+ @dataclass
1023
+ class TeaCacheParams:
1024
+ previous_residual: Optional[torch.Tensor] = None
1025
+ previous_modulated_inp: Optional[torch.Tensor] = None
1026
+ accumulated_rel_l1_distance: float = 0
1027
+ is_first_or_last_step: bool = False
1028
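+ # TeaCache bookkeeping (sketch of intent): the cached residual and modulated input from the
+ # previous step allow a denoising step to be skipped when the accumulated relative L1 change
+ # stays below a threshold, except on the first/last step.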
+
1029
+
1030
+ class TimestepEmbedding(nn.Module):
1031
+ def __init__(
1032
+ self,
1033
+ in_channels: int,
1034
+ time_embed_dim: int,
1035
+ act_fn: str = "silu",
1036
+ out_dim: int = None,
1037
+ post_act_fn: Optional[str] = None,
1038
+ cond_proj_dim=None,
1039
+ sample_proj_bias=True,
1040
+ ):
1041
+ super().__init__()
1042
+
1043
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
1044
+
1045
+ if cond_proj_dim is not None:
1046
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
1047
+ else:
1048
+ self.cond_proj = None
1049
+
1050
+ self.act = get_activation(act_fn)
1051
+
1052
+ if out_dim is not None:
1053
+ time_embed_dim_out = out_dim
1054
+ else:
1055
+ time_embed_dim_out = time_embed_dim
1056
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
1057
+
1058
+ if post_act_fn is None:
1059
+ self.post_act = None
1060
+ else:
1061
+ self.post_act = get_activation(post_act_fn)
1062
+
1063
+ self.initialize_weights()
1064
+
1065
+ def initialize_weights(self):
1066
+ nn.init.normal_(self.linear_1.weight, std=0.02)
1067
+ nn.init.zeros_(self.linear_1.bias)
1068
+ nn.init.normal_(self.linear_2.weight, std=0.02)
1069
+ nn.init.zeros_(self.linear_2.bias)
1070
+
1071
+ def forward(self, sample, condition=None):
1072
+ if condition is not None:
1073
+ sample = sample + self.cond_proj(condition)
1074
+ sample = self.linear_1(sample)
1075
+
1076
+ if self.act is not None:
1077
+ sample = self.act(sample)
1078
+
1079
+ sample = self.linear_2(sample)
1080
+
1081
+ if self.post_act is not None:
1082
+ sample = self.post_act(sample)
1083
+ return sample
1084
+
1085
+ def apply_rotary_emb(
1086
+ x: torch.Tensor,
1087
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
1088
+ use_real: bool = True,
1089
+ use_real_unbind_dim: int = -1,
1090
+ ) -> torch.Tensor:
1091
+ """
1092
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
1093
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
1094
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
1095
+ tensors contain rotary embeddings and are returned as real tensors.
1096
+
1097
+ Args:
1098
+ x (`torch.Tensor`):
1099
+ Query or key tensor of shape [B, H, S, D] to which rotary embeddings are applied.
1100
+ freqs_cis (`Union[torch.Tensor, Tuple[torch.Tensor]]`): Precomputed frequency tensors for the complex exponentials; a (cos, sin) pair of shapes ([S, D], [S, D]) when `use_real` is True.
1101
+
1102
+ Returns:
1103
+ torch.Tensor: The query or key tensor with rotary embeddings applied, returned as a real tensor.
1104
+ """
1105
+ if use_real:
1106
+ cos, sin = freqs_cis # [S, D]
1107
+ cos = cos[None, None]
1108
+ sin = sin[None, None]
1109
+ cos, sin = cos.to(x.device), sin.to(x.device)
1110
+
1111
+ if use_real_unbind_dim == -1:
1112
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
1113
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
1114
+ elif use_real_unbind_dim == -2:
1115
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
1116
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
1117
+ else:
1118
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
1119
+
1120
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
1121
+
1122
+ return out
1123
+ else:
1124
+ # used for lumina
1125
+ # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
1126
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
1127
+ freqs_cis = freqs_cis.unsqueeze(2)
1128
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
1129
+
1130
+ return x_out.type_as(x)
1131
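+ # Sanity-check sketch (hypothetical tensors): a zero-angle rotation is the identity, e.g.
+ # q = torch.randn(1, 8, 16, 64); cos, sin = torch.ones(16, 64), torch.zeros(16, 64)
+ # torch.allclose(apply_rotary_emb(q, (cos, sin)), q)  # True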
+
1132
+ class ThinkGenRotaryPosEmbed(nn.Module):
1133
+ def __init__(self, theta: int,
1134
+ axes_dim: Tuple[int, int, int],
1135
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
1136
+ patch_size: int = 2):
1137
+ super().__init__()
1138
+ self.theta = theta
1139
+ self.axes_dim = axes_dim
1140
+ self.axes_lens = axes_lens
1141
+ self.patch_size = patch_size
1142
+
1143
+ @staticmethod
1144
+ def get_freqs_cis(axes_dim: Tuple[int, int, int],
1145
+ axes_lens: Tuple[int, int, int],
1146
+ theta: int) -> List[torch.Tensor]:
1147
+ freqs_cis = []
1148
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
1149
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
1150
+ emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
1151
+ freqs_cis.append(emb)
1152
+ return freqs_cis
1153
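+ # One 1-D RoPE table is precomputed per axis (e.g. text/time, height, width) with lengths
+ # axes_lens; float64 frequencies are used except on MPS, which lacks float64 support.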
+
1154
+ def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
1155
+ device = ids.device
1156
+ if ids.device.type == "mps":
1157
+ ids = ids.to("cpu")
1158
+
1159
+ result = []
1160
+ for i in range(len(self.axes_dim)):
1161
+ freqs = freqs_cis[i].to(ids.device)
1162
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
1163
+ result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
1164
+ return torch.cat(result, dim=-1).to(device)
1165
+
1166
+ def forward(
1167
+ self,
1168
+ freqs_cis,
1169
+ attention_mask,
1170
+ l_effective_ref_img_len,
1171
+ l_effective_img_len,
1172
+ ref_img_sizes,
1173
+ img_sizes,
1174
+ device
1175
+ ):
1176
+ batch_size = len(attention_mask)
1177
+ p = self.patch_size
1178
+
1179
+ encoder_seq_len = attention_mask.shape[1]
1180
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
1181
+
1182
+ seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
1183
+
1184
+ max_seq_len = max(seq_lengths)
1185
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
1186
+ max_img_len = max(l_effective_img_len)
1187
+
1188
+ # Create position IDs
1189
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
1190
+
1191
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
1192
+ # add text position ids
1193
+ position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
1194
+
1195
+ pe_shift = cap_seq_len
1196
+ pe_shift_len = cap_seq_len
1197
+
1198
+ if ref_img_sizes[i] is not None:
1199
+ for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
1200
+ H, W = ref_img_size
1201
+ ref_H_tokens, ref_W_tokens = H // p, W // p
1202
+ assert ref_H_tokens * ref_W_tokens == ref_img_len
1203
+ # add image position ids
1204
+
1205
+ row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
1206
+ col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
1207
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
1208
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
1209
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
1210
+
1211
+ pe_shift += max(ref_H_tokens, ref_W_tokens)
1212
+ pe_shift_len += ref_img_len
1213
+
1214
+ H, W = img_sizes[i]
1215
+ H_tokens, W_tokens = H // p, W // p
1216
+ assert H_tokens * W_tokens == l_effective_img_len[i]
1217
+
1218
+ row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
1219
+ col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
1220
+
1221
+ assert pe_shift_len + l_effective_img_len[i] == seq_len
1222
+ position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
1223
+ position_ids[i, pe_shift_len: seq_len, 1] = row_ids
1224
+ position_ids[i, pe_shift_len: seq_len, 2] = col_ids
1225
+
1226
+ # Get combined rotary embeddings
1227
+ freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
1228
+
1229
+ # create separate rotary embeddings for captions and images
1230
+ cap_freqs_cis = torch.zeros(
1231
+ batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
1232
+ )
1233
+ ref_img_freqs_cis = torch.zeros(
1234
+ batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
1235
+ )
1236
+ img_freqs_cis = torch.zeros(
1237
+ batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
1238
+ )
1239
+
1240
+ for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
1241
+ cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
1242
+ ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
1243
+ img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
1244
+
1245
+ return (
1246
+ cap_freqs_cis,
1247
+ ref_img_freqs_cis,
1248
+ img_freqs_cis,
1249
+ freqs_cis,
1250
+ l_effective_cap_len,
1251
+ seq_lengths,
1252
+ )
1253
+
1254
+
1255
+ class LuminaRMSNormZero(nn.Module):
1256
+ """
1257
+ Norm layer adaptive RMS normalization zero.
1258
+
1259
+ Parameters:
1260
+ embedding_dim (`int`): The size of each embedding vector.
1261
+ """
1262
+
1263
+ def __init__(
1264
+ self,
1265
+ embedding_dim: int,
1266
+ norm_eps: float,
1267
+ norm_elementwise_affine: bool,
1268
+ ):
1269
+ super().__init__()
1270
+ self.silu = nn.SiLU()
1271
+ self.linear = nn.Linear(
1272
+ min(embedding_dim, 1024),
1273
+ 4 * embedding_dim,
1274
+ bias=True,
1275
+ )
1276
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps)
1277
+
1278
+ def forward(
1279
+ self,
1280
+ x: torch.Tensor,
1281
+ emb: Optional[torch.Tensor] = None,
1282
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1283
+ emb = self.linear(self.silu(emb))
1284
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
1285
+ x = self.norm(x) * (1 + scale_msa[:, None])
1286
+ return x, gate_msa, scale_mlp, gate_mlp
1287
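+ # The conditioning embedding is projected to 4*dim and chunked into (scale_msa, gate_msa,
+ # scale_mlp, gate_mlp); only scale_msa is applied here, the gates and MLP scale are returned
+ # for the caller to use around the attention and feed-forward branches.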
+
1288
+
1289
+ class LuminaLayerNormContinuous(nn.Module):
1290
+ def __init__(
1291
+ self,
1292
+ embedding_dim: int,
1293
+ conditioning_embedding_dim: int,
1294
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
1295
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
1296
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
1297
+ # However, this is how it was implemented in the original code, and it's rather likely you should
1298
+ # set `elementwise_affine` to False.
1299
+ elementwise_affine=True,
1300
+ eps=1e-5,
1301
+ bias=True,
1302
+ norm_type="layer_norm",
1303
+ out_dim: Optional[int] = None,
1304
+ ):
1305
+ super().__init__()
1306
+
1307
+ # AdaLN
1308
+ self.silu = nn.SiLU()
1309
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
1310
+
1311
+ if norm_type == "layer_norm":
1312
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
1313
+ elif norm_type == "rms_norm":
1314
+ # Note: the fused triton RMSNorm defined above accepts (dim, eps) but no `elementwise_affine` kwarg
+ self.norm = RMSNorm(embedding_dim, eps=eps)
1315
+ else:
1316
+ raise ValueError(f"unknown norm_type {norm_type}")
1317
+
1318
+ self.linear_2 = None
1319
+ if out_dim is not None:
1320
+ self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
1321
+
1322
+ def forward(
1323
+ self,
1324
+ x: torch.Tensor,
1325
+ conditioning_embedding: torch.Tensor,
1326
+ ) -> torch.Tensor:
1327
+ # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
1328
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
1329
+ scale = emb
1330
+ x = self.norm(x) * (1 + scale)[:, None, :]
1331
+
1332
+ if self.linear_2 is not None:
1333
+ x = self.linear_2(x)
1334
+
1335
+ return x
1336
+
1337
+
1338
+ class LuminaFeedForward(nn.Module):
1339
+ r"""
1340
+ A feed-forward layer.
1341
+
1342
+ Parameters:
1343
+ dim (`int`):
1344
+ The dimensionality of the input and output. This parameter determines the width of the model's
1345
+ hidden representations.
1346
+ inner_dim (`int`): The intermediate dimension of the feed-forward layer.
1347
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
1348
+ of this value.
1349
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
1350
+ dimension. Defaults to None.
1351
+ """
1352
+
1353
+ def __init__(
1354
+ self,
1355
+ dim: int,
1356
+ inner_dim: int,
1357
+ multiple_of: Optional[int] = 256,
1358
+ ffn_dim_multiplier: Optional[float] = None,
1359
+ ):
1360
+ super().__init__()
1361
+
1362
+ self.swiglu = swiglu
1363
+
1364
+ # custom hidden_size factor multiplier
1365
+ if ffn_dim_multiplier is not None:
1366
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
1367
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
1368
+
1369
+ self.linear_1 = nn.Linear(
1370
+ dim,
1371
+ inner_dim,
1372
+ bias=False,
1373
+ )
1374
+ self.linear_2 = nn.Linear(
1375
+ inner_dim,
1376
+ dim,
1377
+ bias=False,
1378
+ )
1379
+ self.linear_3 = nn.Linear(
1380
+ dim,
1381
+ inner_dim,
1382
+ bias=False,
1383
+ )
1384
+
1385
+ def forward(self, x):
1386
+ h1, h2 = self.linear_1(x), self.linear_3(x)
1387
+ return self.linear_2(self.swiglu(h1, h2))
1388
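+ # SwiGLU feed-forward: linear_1 is the gate projection, linear_3 the value projection and
+ # linear_2 the down projection, i.e. out = W2(SiLU(W1 x) * W3 x).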
+
1389
+
1390
+ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
1391
+ def __init__(
1392
+ self,
1393
+ hidden_size: int = 4096,
1394
+ text_feat_dim: int = 204800, # 2048
1395
+ frequency_embedding_size: int = 256,
1396
+ norm_eps: float = 1e-5,
1397
+ timestep_scale: float = 1.0,
1398
+ ) -> None:
1399
+ super().__init__()
1400
+
1401
+ self.time_proj = Timesteps(
1402
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
1403
+ )
1404
+
1405
+ self.timestep_embedder = TimestepEmbedding(
1406
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
1407
+ )
1408
+
1409
+ self.caption_embedder = nn.Sequential(
1410
+ RMSNorm(text_feat_dim*2, eps=norm_eps),
1411
+ nn.Linear(text_feat_dim*2, hidden_size, bias=True),
1412
+ )
1413
+
1414
+ self._initialize_weights()
1415
+
1416
+ def _initialize_weights(self):
1417
+ for name, module in self.caption_embedder.named_modules():
1418
+ if hasattr(module, 'weight') and module.weight is not None:
1419
+ nn.init.trunc_normal_(module.weight, std=0.02)
1420
+ if hasattr(module, 'bias') and module.bias is not None:
1422
+ nn.init.zeros_(module.bias)
1423
+
1425
+ print("init caption_embedder done")
1426
+
1427
+
1428
+ def forward(
1429
+ self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
1430
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1431
+ timestep_proj = self.time_proj(timestep).to(dtype=dtype)
1432
+ time_embed = self.timestep_embedder(timestep_proj)
1433
+ caption_embed = self.caption_embedder(text_hidden_states)
1434
+ return time_embed, caption_embed
1435
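+ # Shape sketch (assumed dims): timestep [B] and text_hidden_states [B, L, 2*text_feat_dim]
+ # yield a [B, min(hidden_size, 1024)] time embedding and a [B, L, hidden_size] caption embedding.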
+
1436
+
1437
+ class ThinkGenAttnProcessor:
1438
+ """
1439
+ Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
1440
+
1441
+ This processor is optimized for PyTorch 2.0 and implements:
1442
+ - Flash attention with variable length sequences
1443
+ - Rotary position embeddings (RoPE)
1444
+ - Query-Key normalization
1445
+ - Proportional attention scaling
1446
+
1447
+ Args:
1448
+ None
1449
+
1450
+ Raises:
1451
+ ImportError: If PyTorch version is less than 2.0
1452
+ """
1453
+
1454
+ def __init__(self) -> None:
1455
+ """Initialize the attention processor."""
1456
+ if not hasattr(F, "scaled_dot_product_attention"):
1457
+ raise ImportError(
1458
+ "ThinkGenAttnProcessorFlash2Varlen requires PyTorch 2.0. "
1459
+ "Please upgrade PyTorch to version 2.0 or later."
1460
+ )
1461
+
1462
+ def __call__(
1463
+ self,
1464
+ attn: Attention,
1465
+ hidden_states: torch.Tensor,
1466
+ encoder_hidden_states: torch.Tensor,
1467
+ attention_mask: Optional[torch.Tensor] = None,
1468
+ image_rotary_emb: Optional[torch.Tensor] = None,
1469
+ base_sequence_length: Optional[int] = None,
1470
+ ) -> torch.Tensor:
1471
+ """
1472
+ Process attention computation with flash attention.
1473
+
1474
+ Args:
1475
+ attn: Attention module
1476
+ hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
1477
+ encoder_hidden_states: Encoder hidden states tensor
1478
+ attention_mask: Optional attention mask tensor
1479
+ image_rotary_emb: Optional rotary embeddings for image tokens
1480
+ base_sequence_length: Optional base sequence length for proportional attention
1481
+
1482
+ Returns:
1483
+ torch.Tensor: Processed hidden states after attention computation
1484
+ """
1485
+ batch_size, sequence_length, _ = hidden_states.shape
1486
+
1487
+ # Get Query-Key-Value Pair
1488
+ query = attn.to_q(hidden_states)
1489
+ key = attn.to_k(encoder_hidden_states)
1490
+ value = attn.to_v(encoder_hidden_states)
1491
+
1492
+ query_dim = query.shape[-1]
1493
+ inner_dim = key.shape[-1]
1494
+ head_dim = query_dim // attn.heads
1495
+ dtype = query.dtype
1496
+
1497
+ # Get key-value heads
1498
+ kv_heads = inner_dim // head_dim
1499
+
1500
+ # Reshape tensors for attention computation
1501
+ query = query.view(batch_size, -1, attn.heads, head_dim)
1502
+ key = key.view(batch_size, -1, kv_heads, head_dim)
1503
+ value = value.view(batch_size, -1, kv_heads, head_dim)
1504
+
1505
+ # Apply Query-Key normalization
1506
+ if attn.norm_q is not None:
1507
+ query = attn.norm_q(query)
1508
+ if attn.norm_k is not None:
1509
+ key = attn.norm_k(key)
1510
+
1511
+ # Apply Rotary Position Embeddings
1512
+ if image_rotary_emb is not None:
1513
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
1514
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
1515
+
1516
+ query, key = query.to(dtype), key.to(dtype)
1517
+
1518
+ # Calculate attention scale
1519
+ if base_sequence_length is not None:
1520
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
1521
+ else:
1522
+ softmax_scale = attn.scale
1523
+
1524
+ # scaled_dot_product_attention expects attention_mask shape to be
1525
+ # (batch, heads, source_length, target_length)
1526
+ if attention_mask is not None:
1527
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
1528
+
1529
+ query = query.transpose(1, 2)
1530
+ key = key.transpose(1, 2)
1531
+ value = value.transpose(1, 2)
1532
+
1533
+ # Explicitly repeat key/value heads to match the number of query heads; in our tests with PyTorch 2.6, using enable_gqa=True instead forced SDPA onto the MATH backend
1534
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
1535
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
1536
+
1537
+ hidden_states = F.scaled_dot_product_attention(
1538
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
1539
+ )
1540
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1541
+ hidden_states = hidden_states.type_as(query)
1542
+
1543
+ # Apply output projection
1544
+ hidden_states = attn.to_out[0](hidden_states)
1545
+ hidden_states = attn.to_out[1](hidden_states)
1546
+
1547
+ return hidden_states
1548
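+ # Proportional attention note: with base_sequence_length set, the softmax scale becomes
+ # sqrt(log_base(seq_len)) * attn.scale, so seq_len == base_sequence_length recovers the default
+ # scale and longer sequences are scaled up smoothly.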
+
1549
+
1550
+
1551
+ class ThinkGenAttnProcessorFlash2Varlen:
1552
+ """
1553
+ Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
1554
+
1555
+ This processor implements:
1556
+ - Flash attention with variable length sequences
1557
+ - Rotary position embeddings (RoPE)
1558
+ - Query-Key normalization
1559
+ - Proportional attention scaling
1560
+
1561
+ Args:
1562
+ None
1563
+ """
1564
+
1565
+ def __init__(self) -> None:
1566
+ """Initialize the attention processor."""
1567
+ if not is_flash_attn_available():
1568
+ raise ImportError(
1569
+ "ThinkGenAttnProcessorFlash2Varlen requires flash_attn. "
1570
+ "Please install flash_attn."
1571
+ )
1572
+
1573
+ def _upad_input(
1574
+ self,
1575
+ query_layer: torch.Tensor,
1576
+ key_layer: torch.Tensor,
1577
+ value_layer: torch.Tensor,
1578
+ attention_mask: torch.Tensor,
1579
+ query_length: int,
1580
+ num_heads: int,
1581
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
1582
+ """
1583
+ Unpad the input tensors for flash attention.
1584
+
1585
+ Args:
1586
+ query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
1587
+ key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
1588
+ value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
1589
+ attention_mask: Attention mask tensor of shape (batch_size, seq_len)
1590
+ query_length: Length of the query sequence
1591
+ num_heads: Number of attention heads
1592
+
1593
+ Returns:
1594
+ Tuple containing:
1595
+ - Unpadded query tensor
1596
+ - Unpadded key tensor
1597
+ - Unpadded value tensor
1598
+ - Query indices
1599
+ - Tuple of cumulative sequence lengths for query and key
1600
+ - Tuple of maximum sequence lengths for query and key
1601
+ """
1602
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
1603
+ """Helper function to get unpadding data from attention mask."""
1604
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
1605
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
1606
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
1607
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
1608
+ return indices, cu_seqlens, max_seqlen_in_batch
1609
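+ # Worked example (hypothetical mask): attention_mask [[1,1,0],[1,1,1]] gives seqlens [2, 3],
+ # cu_seqlens [0, 2, 5] and max_seqlen_in_batch 3, matching flash_attn's varlen convention.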
+
1610
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
1611
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
1612
+
1613
+ # Unpad key and value layers
1614
+ key_layer = index_first_axis(
1615
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1616
+ indices_k,
1617
+ )
1618
+ value_layer = index_first_axis(
1619
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1620
+ indices_k,
1621
+ )
1622
+
1623
+ # Handle different query length cases
1624
+ if query_length == kv_seq_len:
1625
+ query_layer = index_first_axis(
1626
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
1627
+ indices_k,
1628
+ )
1629
+ cu_seqlens_q = cu_seqlens_k
1630
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
1631
+ indices_q = indices_k
1632
+ elif query_length == 1:
1633
+ max_seqlen_in_batch_q = 1
1634
+ cu_seqlens_q = torch.arange(
1635
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
1636
+ )
1637
+ indices_q = cu_seqlens_q[:-1]
1638
+ query_layer = query_layer.squeeze(1)
1639
+ else:
1640
+ attention_mask = attention_mask[:, -query_length:]
1641
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
1642
+
1643
+ return (
1644
+ query_layer,
1645
+ key_layer,
1646
+ value_layer,
1647
+ indices_q,
1648
+ (cu_seqlens_q, cu_seqlens_k),
1649
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
1650
+ )
1651
+
1652
+ def __call__(
1653
+ self,
1654
+ attn: Attention,
1655
+ hidden_states: torch.Tensor,
1656
+ encoder_hidden_states: torch.Tensor,
1657
+ attention_mask: Optional[torch.Tensor] = None,
1658
+ image_rotary_emb: Optional[torch.Tensor] = None,
1659
+ base_sequence_length: Optional[int] = None,
1660
+ ) -> torch.Tensor:
1661
+ """
1662
+ Process attention computation with flash attention.
1663
+
1664
+ Args:
1665
+ attn: Attention module
1666
+ hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
1667
+ encoder_hidden_states: Encoder hidden states tensor
1668
+ attention_mask: Optional attention mask tensor
1669
+ image_rotary_emb: Optional rotary embeddings for image tokens
1670
+ base_sequence_length: Optional base sequence length for proportional attention
1671
+
1672
+ Returns:
1673
+ torch.Tensor: Processed hidden states after attention computation
1674
+ """
1675
+ batch_size, sequence_length, _ = hidden_states.shape
1676
+
1677
+ # Get Query-Key-Value Pair
1678
+ query = attn.to_q(hidden_states)
1679
+ key = attn.to_k(encoder_hidden_states)
1680
+ value = attn.to_v(encoder_hidden_states)
1681
+
1682
+ query_dim = query.shape[-1]
1683
+ inner_dim = key.shape[-1]
1684
+ head_dim = query_dim // attn.heads
1685
+ dtype = query.dtype
1686
+
1687
+ # Get key-value heads
1688
+ kv_heads = inner_dim // head_dim
1689
+
1690
+ # Reshape tensors for attention computation
1691
+ query = query.view(batch_size, -1, attn.heads, head_dim)
1692
+ key = key.view(batch_size, -1, kv_heads, head_dim)
1693
+ value = value.view(batch_size, -1, kv_heads, head_dim)
1694
+
1695
+ # Apply Query-Key normalization
1696
+ if attn.norm_q is not None:
1697
+ query = attn.norm_q(query)
1698
+ if attn.norm_k is not None:
1699
+ key = attn.norm_k(key)
1700
+
1701
+ # Apply Rotary Position Embeddings
1702
+ if image_rotary_emb is not None:
1703
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
1704
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
1705
+
1706
+ query, key = query.to(dtype), key.to(dtype)
1707
+
1708
+ # Calculate attention scale
1709
+ if base_sequence_length is not None:
1710
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
1711
+ else:
1712
+ softmax_scale = attn.scale
+
+         # Unpad input for flash attention
+         (
+             query_states,
+             key_states,
+             value_states,
+             indices_q,
+             cu_seq_lens,
+             max_seq_lens,
+         ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
+
+         cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+         max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+         # Repeat key/value heads for grouped-query attention (kv_heads < heads)
+         if kv_heads < attn.heads:
+             key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
+             value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
+
+         # Apply flash attention
+         attn_output_unpad = flash_attn_varlen_func(
+             query_states,
+             key_states,
+             value_states,
+             cu_seqlens_q=cu_seqlens_q,
+             cu_seqlens_k=cu_seqlens_k,
+             max_seqlen_q=max_seqlen_in_batch_q,
+             max_seqlen_k=max_seqlen_in_batch_k,
+             dropout_p=0.0,
+             causal=False,
+             softmax_scale=softmax_scale,
+         )
+
+         # Pad output and apply final transformations
+         hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
+         hidden_states = hidden_states.flatten(-2)
+         hidden_states = hidden_states.type_as(query)
+
+         # Apply output projection
+         hidden_states = attn.to_out[0](hidden_states)
+         hidden_states = attn.to_out[1](hidden_states)
+
+         return hidden_states
+
+
+ class ThinkGenTransformerBlock(nn.Module):
+     """
+     Transformer block for ThinkGen model.
+
+     This block implements a transformer layer with:
+     - Multi-head attention with flash attention
+     - Feed-forward network with SwiGLU activation
+     - RMS normalization
+     - Optional modulation for conditional generation
+
+     Args:
+         dim: Dimension of the input and output tensors
+         num_attention_heads: Number of attention heads
+         num_kv_heads: Number of key-value heads
+         multiple_of: The feed-forward hidden dimension is rounded to a multiple of this value
+         ffn_dim_multiplier: Multiplier for the feed-forward network dimension
+         norm_eps: Epsilon value for normalization layers
+         modulation: Whether to use modulation for conditional generation
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         num_kv_heads: int,
+         multiple_of: int,
+         ffn_dim_multiplier: float,
+         norm_eps: float,
+         modulation: bool = True,
+     ) -> None:
+         """Initialize the transformer block."""
+         super().__init__()
+         self.head_dim = dim // num_attention_heads
+         self.modulation = modulation
+
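+         # Prefer the flash-attention varlen processor; fall back to the eager
+         # processor when flash-attn is not installed.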
+         try:
+             processor = ThinkGenAttnProcessorFlash2Varlen()
+         except ImportError:
+             processor = ThinkGenAttnProcessor()
+
+         # Initialize attention layer
+         self.attn = Attention(
+             query_dim=dim,
+             cross_attention_dim=None,
+             dim_head=dim // num_attention_heads,
+             qk_norm="rms_norm",
+             heads=num_attention_heads,
+             kv_heads=num_kv_heads,
+             eps=1e-5,
+             bias=False,
+             out_bias=False,
+             processor=processor,
+         )
+
+         # Initialize feed-forward network
+         self.feed_forward = LuminaFeedForward(
+             dim=dim,
+             inner_dim=4 * dim,
+             multiple_of=multiple_of,
+             ffn_dim_multiplier=ffn_dim_multiplier,
+         )
+
+         # Initialize normalization layers
+         if modulation:
+             self.norm1 = LuminaRMSNormZero(
+                 embedding_dim=dim,
+                 norm_eps=norm_eps,
+                 norm_elementwise_affine=True,
+             )
+         else:
+             self.norm1 = RMSNorm(dim, eps=norm_eps)
+
+         self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+         self.norm2 = RMSNorm(dim, eps=norm_eps)
+         self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+
+         self.initialize_weights()
+
+     def initialize_weights(self) -> None:
+         """
+         Initialize the weights of the transformer block.
+
+         Uses Xavier uniform initialization for linear layers and zero initialization for biases.
+         """
+         nn.init.xavier_uniform_(self.attn.to_q.weight)
+         nn.init.xavier_uniform_(self.attn.to_k.weight)
+         nn.init.xavier_uniform_(self.attn.to_v.weight)
+         nn.init.xavier_uniform_(self.attn.to_out[0].weight)
+
+         nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
+         nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
+         nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
+
+         if self.modulation:
+             nn.init.zeros_(self.norm1.linear.weight)
+             nn.init.zeros_(self.norm1.linear.bias)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         image_rotary_emb: torch.Tensor,
+         temb: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         Forward pass of the transformer block.
+
+         Args:
+             hidden_states: Input hidden states tensor
+             attention_mask: Attention mask tensor
+             image_rotary_emb: Rotary embeddings for image tokens
+             temb: Optional timestep embedding tensor
+
+         Returns:
+             torch.Tensor: Output hidden states after transformer block processing
+         """
+         enable_taylorseer = getattr(self, 'enable_taylorseer', False)
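+         # TaylorSeer caching: on 'full' steps the block is evaluated exactly
+         # and its output updates a Taylor-series cache of the feature
+         # trajectory; on 'Taylor' steps the output is extrapolated from that
+         # cache instead of recomputing attention and the feed-forward network.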
+         if enable_taylorseer:
+             if self.modulation:
+                 if temb is None:
+                     raise ValueError("temb must be provided when modulation is enabled")
+
+                 if self.current['type'] == 'full':
+                     self.current['module'] = 'total'
+                     taylor_cache_init(cache_dic=self.cache_dic, current=self.current)
+
+                     norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+                     attn_output = self.attn(
+                         hidden_states=norm_hidden_states,
+                         encoder_hidden_states=norm_hidden_states,
+                         attention_mask=attention_mask,
+                         image_rotary_emb=image_rotary_emb,
+                     )
+                     hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
+                     mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+                     hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+
+                     derivative_approximation(cache_dic=self.cache_dic, current=self.current, feature=hidden_states)
+
+                 elif self.current['type'] == 'Taylor':
+                     self.current['module'] = 'total'
+                     hidden_states = taylor_formula(cache_dic=self.cache_dic, current=self.current)
+             else:
+                 norm_hidden_states = self.norm1(hidden_states)
+                 attn_output = self.attn(
+                     hidden_states=norm_hidden_states,
+                     encoder_hidden_states=norm_hidden_states,
+                     attention_mask=attention_mask,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+                 hidden_states = hidden_states + self.norm2(attn_output)
+                 mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
+                 hidden_states = hidden_states + self.ffn_norm2(mlp_output)
+         else:
+             if self.modulation:
+                 if temb is None:
+                     raise ValueError("temb must be provided when modulation is enabled")
+
+                 norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+                 attn_output = self.attn(
+                     hidden_states=norm_hidden_states,
+                     encoder_hidden_states=norm_hidden_states,
+                     attention_mask=attention_mask,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+                 hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
+                 mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+                 hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+             else:
+                 norm_hidden_states = self.norm1(hidden_states)
+                 attn_output = self.attn(
+                     hidden_states=norm_hidden_states,
+                     encoder_hidden_states=norm_hidden_states,
+                     attention_mask=attention_mask,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+                 hidden_states = hidden_states + self.norm2(attn_output)
+                 mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
+                 hidden_states = hidden_states + self.ffn_norm2(mlp_output)
+
+         return hidden_states
+
+
+ class ThinkGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+     """
+     ThinkGen Transformer 2D Model.
+
+     A transformer-based diffusion model for image generation with:
+     - Patch-based image processing
+     - Rotary position embeddings
+     - Multi-head attention
+     - Conditional generation support
+
+     Args:
+         patch_size: Size of image patches
+         in_channels: Number of input channels
+         out_channels: Number of output channels (defaults to in_channels)
+         hidden_size: Size of hidden layers
+         num_layers: Number of transformer layers
+         num_refiner_layers: Number of refiner layers
+         num_attention_heads: Number of attention heads
+         num_kv_heads: Number of key-value heads
+         multiple_of: The feed-forward hidden dimension is rounded to a multiple of this value
+         ffn_dim_multiplier: Multiplier for feed-forward network dimension
+         norm_eps: Epsilon value for normalization layers
+         axes_dim_rope: Dimensions for rotary position embeddings
+         axes_lens: Lengths for rotary position embeddings
+         text_feat_dim: Dimension of text features
+         timestep_scale: Scale factor for timestep embeddings
+     """
+
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["ThinkGenTransformerBlock"]
+     _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: int = 2,
+         in_channels: int = 16,
+         out_channels: Optional[int] = None,
+         hidden_size: int = 2304,
+         num_layers: int = 26,
+         num_refiner_layers: int = 2,
+         num_attention_heads: int = 24,
+         num_kv_heads: int = 8,
+         multiple_of: int = 256,
+         ffn_dim_multiplier: Optional[float] = None,
+         norm_eps: float = 1e-5,
+         axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
+         axes_lens: Tuple[int, int, int] = (300, 512, 512),
+         text_feat_dim: int = 1024,
+         timestep_scale: float = 1.0,
+     ) -> None:
+         """Initialize the ThinkGen transformer model."""
+         super().__init__()
+
+         # Validate configuration
+         if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
+             raise ValueError(
+                 f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
+                 f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
+             )
+
+         self.out_channels = out_channels or in_channels
+
+         # Initialize embeddings
+         self.rope_embedder = ThinkGenRotaryPosEmbed(
+             theta=10000,
+             axes_dim=axes_dim_rope,
+             axes_lens=axes_lens,
+             patch_size=patch_size,
+         )
+
+         self.x_embedder = nn.Linear(
+             in_features=patch_size * patch_size * in_channels,
+             out_features=hidden_size,
+         )
+
+         self.ref_image_patch_embedder = nn.Linear(
+             in_features=patch_size * patch_size * in_channels,
+             out_features=hidden_size,
+         )
+
+         self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
+             hidden_size=hidden_size,
+             text_feat_dim=text_feat_dim,
+             norm_eps=norm_eps,
+             timestep_scale=timestep_scale,
+         )
+
+         # Initialize refiner blocks
+         self.noise_refiner = nn.ModuleList([
+             ThinkGenTransformerBlock(
+                 hidden_size,
+                 num_attention_heads,
+                 num_kv_heads,
+                 multiple_of,
+                 ffn_dim_multiplier,
+                 norm_eps,
+                 modulation=True,
+             )
+             for _ in range(num_refiner_layers)
+         ])
+
+         self.ref_image_refiner = nn.ModuleList([
+             ThinkGenTransformerBlock(
+                 hidden_size,
+                 num_attention_heads,
+                 num_kv_heads,
+                 multiple_of,
+                 ffn_dim_multiplier,
+                 norm_eps,
+                 modulation=True,
+             )
+             for _ in range(num_refiner_layers)
+         ])
+
+         self.context_refiner = nn.ModuleList([
+             ThinkGenTransformerBlock(
+                 hidden_size,
+                 num_attention_heads,
+                 num_kv_heads,
+                 multiple_of,
+                 ffn_dim_multiplier,
+                 norm_eps,
+                 modulation=False,
+             )
+             for _ in range(num_refiner_layers)
+         ])
+
+         # Main transformer blocks
+         self.layers = nn.ModuleList([
+             ThinkGenTransformerBlock(
+                 hidden_size,
+                 num_attention_heads,
+                 num_kv_heads,
+                 multiple_of,
+                 ffn_dim_multiplier,
+                 norm_eps,
+                 modulation=True,
+             )
+             for _ in range(num_layers)
+         ])
+
+         # Output norm & projection
+         self.norm_out = LuminaLayerNormContinuous(
+             embedding_dim=hidden_size,
+             conditioning_embedding_dim=min(hidden_size, 1024),
+             elementwise_affine=False,
+             eps=1e-6,
+             bias=True,
+             out_dim=patch_size * patch_size * self.out_channels,
+         )
+
+         # Learnable embeddings to distinguish different images
+         self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size))  # supports at most 5 ref images
+
+         self.gradient_checkpointing = False
+
+         self.initialize_weights()
+
+         # TeaCache settings
+         self.enable_teacache = False
+         self.teacache_rel_l1_thresh = 0.05
+         self.teacache_params = TeaCacheParams()
+
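+         # Empirical polynomial (from TeaCache) that rescales the raw
+         # relative-L1 change of the modulated input before it is accumulated
+         # against teacache_rel_l1_thresh in forward().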
+         coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
+         self.rescale_func = np.poly1d(coefficients)
+
+         self.prepad_embed = nn.Parameter(torch.randn(1, 23, 8192))
+         print("Added prepad_embed parameter!")
+
+         self.register_buffer('prepad_mask', torch.ones(1, 23).to(torch.int64))
+
+
+     def initialize_weights(self) -> None:
+         """
+         Initialize the weights of the model.
+
+         Uses Xavier uniform initialization for linear layers.
+         """
+         nn.init.xavier_uniform_(self.x_embedder.weight)
+         nn.init.constant_(self.x_embedder.bias, 0.0)
+
+         nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
+         nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
+
+         nn.init.zeros_(self.norm_out.linear_1.weight)
+         nn.init.zeros_(self.norm_out.linear_1.bias)
+         nn.init.zeros_(self.norm_out.linear_2.weight)
+         nn.init.zeros_(self.norm_out.linear_2.bias)
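+         # The zero-initialized output norm/projection makes the final
+         # prediction start at zero, a common stabilization for diffusion
+         # transformers.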
+
+         nn.init.normal_(self.image_index_embedding, std=0.02)
+
+     def img_patch_embed_and_refine(
+         self,
+         hidden_states,
+         ref_image_hidden_states,
+         padded_img_mask,
+         padded_ref_img_mask,
+         noise_rotary_emb,
+         ref_img_rotary_emb,
+         l_effective_ref_img_len,
+         l_effective_img_len,
+         temb,
+     ):
+         batch_size = len(hidden_states)
+         max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
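+         # The combined per-sample sequence is all reference-image tokens
+         # followed by the noisy-image tokens.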
+
+         hidden_states = self.x_embedder(hidden_states)
+         ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
+
+         # Add image_index_embedding so tokens from different reference images are distinguishable
+         for i in range(batch_size):
+             shift = 0
+             for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
+                 ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
+                 shift += ref_img_len
+
+         for layer in self.noise_refiner:
+             hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
+
+         flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
+         num_ref_images = len(flat_l_effective_ref_img_len)
+         max_ref_img_len = max(flat_l_effective_ref_img_len)
+
+         batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
+         batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
+         batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
+         batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
+
+         # Scatter the per-sample sequences of ref-image tokens into a flat batch
+         idx = 0
+         for i in range(batch_size):
+             shift = 0
+             for ref_img_len in l_effective_ref_img_len[i]:
+                 batch_ref_img_mask[idx, :ref_img_len] = True
+                 batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
+                 batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
+                 batch_temb[idx] = temb[i]
+                 shift += ref_img_len
+                 idx += 1
+
+         # Refine ref images separately
+         for layer in self.ref_image_refiner:
+             batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
+
+         # Gather the refined ref-image batch back into per-sample sequences
+         idx = 0
+         for i in range(batch_size):
+             shift = 0
+             for ref_img_len in l_effective_ref_img_len[i]:
+                 ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
+                 shift += ref_img_len
+                 idx += 1
+
+         combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
+         for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
+             combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
+             combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
+
+         return combined_img_hidden_states
+
+     def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
+         batch_size = len(hidden_states)
+         p = self.config.patch_size
+         device = hidden_states[0].device
+
+         img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
+         l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
+
+         if ref_image_hidden_states is not None:
+             ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
+             l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
+         else:
+             ref_img_sizes = [None for _ in range(batch_size)]
+             l_effective_ref_img_len = [[0] for _ in range(batch_size)]
+
+         max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
+         max_img_len = max(l_effective_img_len)
+
+         # Patchify reference images into flat token sequences
+         flat_ref_img_hidden_states = []
+         for i in range(batch_size):
+             if ref_img_sizes[i] is not None:
+                 imgs = []
+                 for ref_img in ref_image_hidden_states[i]:
+                     C, H, W = ref_img.size()
+                     ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
+                     imgs.append(ref_img)
+
+                 img = torch.cat(imgs, dim=0)
+                 flat_ref_img_hidden_states.append(img)
+             else:
+                 flat_ref_img_hidden_states.append(None)
+
+         # Patchify target images into flat token sequences
+         flat_hidden_states = []
+         for i in range(batch_size):
+             img = hidden_states[i]
+             C, H, W = img.size()
+
+             img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
+             flat_hidden_states.append(img)
+
+         padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
+         padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
+         for i in range(batch_size):
+             if ref_img_sizes[i] is not None:
+                 padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
+                 padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
+
+         padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
+         padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
+         for i in range(batch_size):
+             padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
+             padded_img_mask[i, :l_effective_img_len[i]] = True
+
+         return (
+             padded_hidden_states,
+             padded_ref_img_hidden_states,
+             padded_img_mask,
+             padded_ref_img_mask,
+             l_effective_ref_img_len,
+             l_effective_img_len,
+             ref_img_sizes,
+             img_sizes,
+         )
+
+     def forward(
+         self,
+         hidden_states: Union[torch.Tensor, List[torch.Tensor]],
+         timestep: torch.Tensor,
+         text_hidden_states: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         text_attention_mask: torch.Tensor,
+         ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = False,
+     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+         enable_taylorseer = getattr(self, 'enable_taylorseer', False)
+
+         # if self.prepad_embed.dtype != text_hidden_states.dtype:
+         #     self.prepad_embed = self.prepad_embed.to(text_hidden_states.dtype)
+         # if self.prepad_mask.device != text_attention_mask.device:
+         #     self.prepad_mask = self.prepad_mask.to(text_attention_mask.device)
+
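+         # Prepend the 23 learnable prepad embeddings to the text features and
+         # extend the text attention mask with matching ones.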
+         bs = text_hidden_states.shape[0]
+         prepad_embed = self.prepad_embed.repeat(bs, 1, 1)
+         prepad_mask = self.prepad_mask.repeat(bs, 1)
+         text_hidden_states = torch.cat([prepad_embed, text_hidden_states], dim=1)
+         text_attention_mask = torch.cat([prepad_mask, text_attention_mask], dim=1)
+
+         if enable_taylorseer:
+             cal_type(self.cache_dic, self.current)
+
+         if attention_kwargs is not None:
+             attention_kwargs = attention_kwargs.copy()
+             lora_scale = attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                 logger.warning(
+                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+
+         # 1. Condition, positional & patch embedding
+         batch_size = len(hidden_states)
+         is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
+
+         if is_hidden_states_tensor:
+             assert hidden_states.ndim == 4
+             hidden_states = [_hidden_states for _hidden_states in hidden_states]
+
+         device = hidden_states[0].device
+
+         temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
+
+         (
+             hidden_states,
+             ref_image_hidden_states,
+             img_mask,
+             ref_img_mask,
+             l_effective_ref_img_len,
+             l_effective_img_len,
+             ref_img_sizes,
+             img_sizes,
+         ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
+
+         (
+             context_rotary_emb,
+             ref_img_rotary_emb,
+             noise_rotary_emb,
+             rotary_emb,
+             encoder_seq_lengths,
+             seq_lengths,
+         ) = self.rope_embedder(
+             freqs_cis,
+             text_attention_mask,
+             l_effective_ref_img_len,
+             l_effective_img_len,
+             ref_img_sizes,
+             img_sizes,
+             device,
+         )
+
+         # 2. Context refinement
+         for layer in self.context_refiner:
+             text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
+
+         combined_img_hidden_states = self.img_patch_embed_and_refine(
+             hidden_states,
+             ref_image_hidden_states,
+             img_mask,
+             ref_img_mask,
+             noise_rotary_emb,
+             ref_img_rotary_emb,
+             l_effective_ref_img_len,
+             l_effective_img_len,
+             temb,
+         )
+
+         # 3. Joint transformer blocks over the concatenated text and image embeddings
+         max_seq_len = max(seq_lengths)
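+         # Per-sample joint layout: [text tokens | ref-image tokens | noisy-image tokens]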
+
+         attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+         joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
+         for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+             attention_mask[i, :seq_len] = True
+             joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
+             joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
+
+         hidden_states = joint_hidden_states
+
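+         # TeaCache: estimate how much the block inputs changed since the last
+         # computed step; if the accumulated rescaled relative-L1 change stays
+         # below teacache_rel_l1_thresh, skip the transformer stack and reuse
+         # the cached residual.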
+         if self.enable_teacache:
+             teacache_hidden_states = hidden_states.clone()
+             teacache_temb = temb.clone()
+             modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
+             if self.teacache_params.is_first_or_last_step:
+                 should_calc = True
+                 self.teacache_params.accumulated_rel_l1_distance = 0
+             else:
+                 self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
+                     ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean()
+                      / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
+                 )
+                 if self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
+                     should_calc = False
+                 else:
+                     should_calc = True
+                     self.teacache_params.accumulated_rel_l1_distance = 0
+             self.teacache_params.previous_modulated_inp = modulated_inp
+
+         if self.enable_teacache:
+             if not should_calc:
+                 hidden_states += self.teacache_params.previous_residual
+             else:
+                 ori_hidden_states = hidden_states.clone()
+                 for layer_idx, layer in enumerate(self.layers):
+                     if torch.is_grad_enabled() and self.gradient_checkpointing:
+                         hidden_states = self._gradient_checkpointing_func(
+                             layer, hidden_states, attention_mask, rotary_emb, temb
+                         )
+                     else:
+                         hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+                 self.teacache_params.previous_residual = hidden_states - ori_hidden_states
+         else:
+             if enable_taylorseer:
+                 self.current['stream'] = 'layers_stream'
+
+             for layer_idx, layer in enumerate(self.layers):
+                 if enable_taylorseer:
+                     layer.current = self.current
+                     layer.cache_dic = self.cache_dic
+                     layer.enable_taylorseer = True
+                     self.current['layer'] = layer_idx
+
+                 if torch.is_grad_enabled() and self.gradient_checkpointing:
+                     hidden_states = self._gradient_checkpointing_func(
+                         layer, hidden_states, attention_mask, rotary_emb, temb
+                     )
+                 else:
+                     hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+
+         # 4. Output norm & projection
+         hidden_states = self.norm_out(hidden_states, temb)
+
+         p = self.config.patch_size
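+         # Unpatchify: each sample's image tokens sit at the tail of its joint
+         # sequence; fold the patch dimension back into (C, H, W).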
+         output = []
+         for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
+             height, width = img_size
+             output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
+         if is_hidden_states_tensor:
+             output = torch.stack(output, dim=0)
+
+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+
+         if enable_taylorseer:
+             self.current['step'] += 1
+
+         if not return_dict:
+             return output
+         return Transformer2DModelOutput(sample=output)
vae/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.34.1",
+   "_name_or_path": "/share_2/luoxin/modelscope/hub/models/FLUX.1-dev",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 16,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 2,
+   "mid_block_add_attention": true,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 1024,
+   "scaling_factor": 0.3611,
+   "shift_factor": 0.1159,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ],
+   "use_post_quant_conv": false,
+   "use_quant_conv": false
+ }
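The config above matches the FLUX.1-dev autoencoder (16 latent channels, `scaling_factor` 0.3611, `shift_factor` 0.1159). As a minimal sketch, assuming the repo id from this upload and the standard diffusers subfolder layout, the VAE can be loaded on its own:

```python
import torch
from diffusers import AutoencoderKL

# Hypothetical standalone load; the repo id and "vae" subfolder are
# assumptions based on this upload, not a documented entry point.
vae = AutoencoderKL.from_pretrained(
    "JSYuuu/ThinkGen-stage3", subfolder="vae", torch_dtype=torch.bfloat16
)

# Decoding latents back to pixels should undo the shift/scale recorded in the
# config: image = vae.decode(latents / vae.config.scaling_factor
#                            + vae.config.shift_factor).sample
```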
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8c717328c8ad41faab2ccfd52ae17332505c6833cf176aad56e7b58f2c4d4c94
+ size 335306212