kelseye committed
Commit 662278d · verified · 1 Parent(s): 86f643a

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/cat_Upscaler_1.png filter=lfs diff=lfs merge=lfs -text
+ assets/cat_Upscaler_2.png filter=lfs diff=lfs merge=lfs -text
+ assets/cat_lowers_512.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/food_Upscaler_1.png filter=lfs diff=lfs merge=lfs -text
+ assets/food_Upscaler_2.png filter=lfs diff=lfs merge=lfs -text
+ assets/food_lowers_100.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/food_lowers_512.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/girl_Upscaler_1.png filter=lfs diff=lfs merge=lfs -text
+ assets/girl_Upscaler_2.png filter=lfs diff=lfs merge=lfs -text
+ assets/girl_lowers_100.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/girl_lowers_512.jpg filter=lfs diff=lfs merge=lfs -text
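
The same LFS rules can be appended without editing `.gitattributes` by hand; `git lfs track` writes identical `filter=lfs diff=lfs merge=lfs -text` attributes (the globs below are illustrative):

```shell
git lfs track "assets/*.png"
git lfs track "assets/*.jpg"
```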
README.md ADDED
@@ -0,0 +1,199 @@
---
license: apache-2.0
---
# Templates-Super-Resolution (FLUX.2-klein-base-4B)

This model is part of the Diffusion Templates series in the open-source [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) project. Built for image super-resolution, it takes a low-resolution input image and redraws it with rich, high-definition detail while preserving the original composition and semantics.

## Results

> **Prompt:** A cat is sitting on a stone.

| Input (100px) | Output | Input (512px) | Output |
|:---:|:---:|:---:|:---:|
| ![](./assets/cat_lowers_100.jpg) | ![](./assets/cat_Upscaler_1.png) | ![](./assets/cat_lowers_512.jpg) | ![](./assets/cat_Upscaler_2.png) |

---

> **Prompt:** An anime girl under a cherry blossom tree, looking at the sky.

| Input (100px) | Output | Input (512px) | Output |
|:---:|:---:|:---:|:---:|
| ![](./assets/girl_lowers_100.jpg) | ![](./assets/girl_Upscaler_1.png) | ![](./assets/girl_lowers_512.jpg) | ![](./assets/girl_Upscaler_2.png) |

---

> **Prompt:** A hamburger with fries on a plate.

| Input (100px) | Output | Input (512px) | Output |
|:---:|:---:|:---:|:---:|
| ![](./assets/food_lowers_100.jpg) | ![](./assets/food_Upscaler_1.png) | ![](./assets/food_lowers_512.jpg) | ![](./assets/food_Upscaler_2.png) |

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference (requires 40 GB of GPU memory)

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image

# Load the FLUX.2 base pipeline (transformer, text encoder, VAE, tokenizer).
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
# Load the super-resolution template model on top of the base pipeline.
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Upscaler")],
)
# Download the sample low-resolution images.
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
# Upscale a 512px input.
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_512.jpg"),
        "prompt": "A cat is sitting on a stone.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_512.jpg"),
        "prompt": "",
    }],
)
image.save("image_Upscaler_1.png")
# Upscale a 100px input.
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_100.jpg"),
        "prompt": "A cat is sitting on a stone.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_100.jpg"),
        "prompt": "",
    }],
)
image.save("image_Upscaler_2.png")
```
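
To run the example on your own image instead of the downloaded samples, any RGB `PIL.Image` can be passed as the template input. A minimal sketch (the file name and the 512px cap are placeholder values):

```python
from PIL import Image

lowres = Image.open("my_photo.jpg").convert("RGB")  # placeholder path
lowres.thumbnail((512, 512))  # keep the aspect ratio; cap the long side at 512px
# Pass `lowres` as the "image" value in template_inputs and
# negative_template_inputs above.
```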

* Enable lazy loading and memory management (requires 24 GB of GPU memory)

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image

# Offload plan: weights rest on disk, are cached in CPU RAM as FP8, staged to
# the GPU as FP8, and upcast to bfloat16 only for computation.
vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    # Budget: total VRAM in GiB minus 0.5 GiB of headroom
    # (torch.cuda.mem_get_info returns (free, total) in bytes).
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Upscaler")],
    lazy_loading=True,
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_512.jpg"),
        "prompt": "A cat is sitting on a stone.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_512.jpg"),
        "prompt": "",
    }],
)
image.save("image_Upscaler_1.png")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_100.jpg"),
        "prompt": "A cat is sitting on a stone.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_100.jpg"),
        "prompt": "",
    }],
)
image.save("image_Upscaler_2.png")
```

## Training Code

After installing DiffSynth-Studio, use the following script to start training. For more information, please refer to the [DiffSynth-Studio Documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-Upscaler/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Upscaler \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Upscaler/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-Upscaler:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Upscaler_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
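
Once training finishes, the saved weights can be plugged back into the inference flow above. A minimal sketch, assuming `ModelConfig` accepts a local `path` argument and that the run produced an `epoch-1.safetensors` file (both are assumptions; adjust to your actual output):

```python
import torch
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import ModelConfig

# Hypothetical checkpoint path written by the training run above.
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(path="models/train/Template-KleinBase4B-Upscaler_full/epoch-1.safetensors")],
)
```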
README_from_modelscope.md ADDED
@@ -0,0 +1,202 @@
---
frameworks:
- Pytorch
license: Apache License 2.0
tags: []
tasks:
- text-to-image-synthesis
---

# Templates-Super-Resolution (FLUX.2-klein-base-4B)

This model is part of the Diffusion Templates series open-sourced in [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). Built for image super-resolution, it takes a low-resolution input image and, while preserving the original composition and semantics, redraws it with rich high-definition detail.

## Results

> **Prompt:** A cat is sitting on a stone.

| Input (100px) | Output | Input (512px) | Output |
|:---:|:---:|:---:|:---:|
| ![](./assets/cat_lowers_100.jpg) | ![](./assets/cat_Upscaler_1.png) | ![](./assets/cat_lowers_512.jpg) | ![](./assets/cat_Upscaler_2.png) |

---

> **Prompt:** An anime girl under a cherry blossom tree, looking at the sky.

| Input (100px) | Output | Input (512px) | Output |
|:---:|:---:|:---:|:---:|
| ![](./assets/girl_lowers_100.jpg) | ![](./assets/girl_Upscaler_1.png) | ![](./assets/girl_lowers_512.jpg) | ![](./assets/girl_Upscaler_2.png) |

---

> **Prompt:** A hamburger with fries on a plate.

| Input (100px) | Output | Input (512px) | Output |
|:---:|:---:|:---:|:---:|
| ![](./assets/food_lowers_100.jpg) | ![](./assets/food_Upscaler_1.png) | ![](./assets/food_lowers_512.jpg) | ![](./assets/food_Upscaler_2.png) |

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference (requires 40 GB of GPU memory)

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image

pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Upscaler")],
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_512.jpg"),
        "prompt": "A cat is sitting on a stone.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_512.jpg"),
        "prompt": "",
    }],
)
image.save("image_Upscaler_1.png")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_100.jpg"),
        "prompt": "A cat is sitting on a stone.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_100.jpg"),
        "prompt": "",
    }],
)
image.save("image_Upscaler_2.png")
```

* Enable lazy loading and memory management (requires 24 GB of GPU memory)

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Upscaler")],
    lazy_loading=True,
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_512.jpg"),
        "prompt": "A cat is sitting on a stone.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_512.jpg"),
        "prompt": "",
    }],
)
image.save("image_Upscaler_1.png")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_100.jpg"),
        "prompt": "A cat is sitting on a stone.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_lowres_100.jpg"),
        "prompt": "",
    }],
)
image.save("image_Upscaler_2.png")
```
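
A quick way to decide between the two paths on a given machine (a sketch; the threshold simply restates the figures above):

```python
import torch

free, total = torch.cuda.mem_get_info("cuda")  # both in bytes
total_gb = total / 1024 ** 3
# Direct inference needs roughly 40 GB; otherwise use the lazy-loading path.
print(f"{total_gb:.1f} GB VRAM -> {'direct' if total_gb >= 40 else 'lazy-loading'} inference")
```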

## Training Code

After installing DiffSynth-Studio, use the following script to start training. For more information, see the [DiffSynth-Studio documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-Upscaler/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Upscaler \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Upscaler/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-Upscaler:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Upscaler_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
assets/cat_Upscaler_1.png ADDED

Git LFS Details

  • SHA256: 6fff077590cbfb281131d769805e439b03b3acd7b4b7f8998061fb18320f54bd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
assets/cat_Upscaler_2.png ADDED

Git LFS Details

  • SHA256: 61cdfbd59cbaba9a7b2440b16fe1cd7f682869e93dcafb2c6f98b99427ccd60c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
assets/cat_lowers_100.jpg ADDED
assets/cat_lowers_512.jpg ADDED

Git LFS Details

  • SHA256: 3b4d4d61bd9cd5da03273f31096e07985048f28c74ff1bab3b8b3ab32929cfbe
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
assets/food_Upscaler_1.png ADDED

Git LFS Details

  • SHA256: ceba403ba80989a22703ec61e5c58cdfdfbb286ecc37d02a7ae3cb699ef6b45d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
assets/food_Upscaler_2.png ADDED

Git LFS Details

  • SHA256: a8f3893f5af3c505ed8f0b1adee1ecd4e05fb519f085ddb465f7370a66dd58d0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
assets/food_lowers_100.jpg ADDED

Git LFS Details

  • SHA256: 2896030579104867d282ef58479926839f70a79e6ba4601e7ae3d396ad3dd973
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB
assets/food_lowers_512.jpg ADDED

Git LFS Details

  • SHA256: 67f5d77f3b2483b563a74aad3e869a8f0518e1a8f7dccbed077e9e762ec30d06
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
assets/girl_Upscaler_1.png ADDED

Git LFS Details

  • SHA256: 65b0488b12812eb5e94b307c4afef81672f1816eddf42cd30e3c273b0d51169f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
assets/girl_Upscaler_2.png ADDED

Git LFS Details

  • SHA256: 93cb0a3fb729380769e405ee8beb70d8f5fa08a1f6efe95589f3eb70118acbb3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
assets/girl_lowers_100.jpg ADDED

Git LFS Details

  • SHA256: 7f061be3cf70241ad9c5a85d5518824b9721c878c4487d2e6eadeccb672be7c5
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
assets/girl_lowers_512.jpg ADDED

Git LFS Details

  • SHA256: 13c72f151d6f750fc1da909c8e321f935f63ff009b279c150e22e694df670f1a
  • Pointer size: 131 Bytes
  • Size of remote file: 210 kB
configuration.json ADDED
@@ -0,0 +1 @@
{"framework":"Pytorch","task":"text-to-image-synthesis"}
model.py ADDED
@@ -0,0 +1,333 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from diffsynth.core.attention import attention_forward
from diffsynth.core.gradient import gradient_checkpoint_forward
from diffsynth.models.flux2_dit import apply_rotary_emb, Flux2PosEmbed
from diffsynth.models.general_modules import get_timestep_embedding


class AdaLayerNormContinuous(nn.Module):
    def __init__(self, dim_in, dim_out, eps=1e-6):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out * 2, bias=False)
        self.norm = nn.LayerNorm(dim_in, eps=eps, elementwise_affine=False, bias=False)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # Scale and shift the normalized activations with parameters predicted
        # from the conditioning embedding.
        scale, shift = self.linear(torch.nn.functional.silu(conditioning_embedding)).chunk(2, dim=1)
        x = self.norm(x) * (1 + scale) + shift
        return x


class Flux2FeedForward(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear_in = nn.Linear(dim, dim * 3 * 2, bias=False)
        self.linear_out = nn.Linear(dim * 3, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style feed-forward: one half of the projection gates the other.
        x1, x2 = self.linear_in(x).chunk(2, dim=-1)
        x = torch.nn.functional.silu(x1) * x2
        x = self.linear_out(x)
        return x


class Flux2TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, eps=1e-6):
        super().__init__()
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_ff = Flux2FeedForward(dim)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_ff = Flux2FeedForward(dim)

        self.num_heads = num_heads
        self.img_to_qkv = torch.nn.Linear(dim, 3 * dim, bias=False)
        self.img_norm_q = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.img_norm_k = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.img_to_out = torch.nn.Linear(dim, dim, bias=False)
        self.txt_to_qkv = torch.nn.Linear(dim, 3 * dim, bias=False)
        self.txt_norm_q = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.txt_norm_k = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.txt_to_out = torch.nn.Linear(dim, dim, bias=False)

    def attention(self, img: torch.Tensor, txt: torch.Tensor, image_rotary_emb: torch.Tensor, **kwargs):
        img_q, img_k, img_v = self.img_to_qkv(img).chunk(3, dim=-1)
        txt_q, txt_k, txt_v = self.txt_to_qkv(txt).chunk(3, dim=-1)
        img_q, img_k, img_v, txt_q, txt_k, txt_v = tuple(map(lambda x: x.unflatten(-1, (self.num_heads, -1)), (img_q, img_k, img_v, txt_q, txt_k, txt_v)))
        img_q = self.img_norm_q(img_q)
        img_k = self.img_norm_k(img_k)
        txt_q = self.txt_norm_q(txt_q)
        txt_k = self.txt_norm_k(txt_k)

        # Joint attention over the concatenated text-plus-image sequence.
        q = torch.cat([txt_q, img_q], dim=1)
        k = torch.cat([txt_k, img_k], dim=1)
        v = torch.cat([txt_v, img_v], dim=1)
        q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
        k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        img = attention_forward(q, k, v, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s (n d)")
        txt, img = img.split_with_sizes([txt.shape[1], img.shape[1] - txt.shape[1]], dim=1)
        txt = self.txt_to_out(txt)
        img = self.img_to_out(img)
        # The rotated keys/values are returned so that they can be cached.
        return img, txt, (k, v)

    def forward(self, img, txt, temb_mod_params_img, temb_mod_params_txt, image_rotary_emb):
        (img_shift_msa, img_scale_msa, img_gate_msa), (img_shift_mlp, img_scale_mlp, img_gate_mlp) = temb_mod_params_img
        (txt_shift_msa, txt_scale_msa, txt_gate_msa), (txt_shift_mlp, txt_scale_mlp, txt_gate_mlp) = temb_mod_params_txt

        norm_img = (1 + img_scale_msa) * self.img_norm1(img) + img_shift_msa
        norm_txt = (1 + txt_scale_msa) * self.txt_norm1(txt) + txt_shift_msa
        img_attn_out, txt_attn_out, kv_cache = self.attention(norm_img, norm_txt, image_rotary_emb)

        img = img + img_gate_msa * img_attn_out
        norm_img = self.img_norm2(img) * (1 + img_scale_mlp) + img_shift_mlp
        img = img + img_gate_mlp * self.img_ff(norm_img)

        txt = txt + txt_gate_msa * txt_attn_out
        norm_txt = self.txt_norm2(txt) * (1 + txt_scale_mlp) + txt_shift_mlp
        txt = txt + txt_gate_mlp * self.txt_ff(norm_txt)
        return txt, img, kv_cache


class Flux2SingleTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.dim = dim
        self.num_heads = num_heads
        self.norm_q = torch.nn.RMSNorm(dim // num_heads, eps=eps, elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim // num_heads, eps=eps, elementwise_affine=True)
        # A fused projection produces QKV and the MLP branch in one matmul.
        self.to_qkv_mlp_proj = torch.nn.Linear(dim, dim * 3 + dim * 3 * 2, bias=False)
        self.to_out = torch.nn.Linear(dim + dim * 3, dim, bias=False)

    def attention(self, x: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, **kwargs):
        x = self.to_qkv_mlp_proj(x)
        qkv, mlp_x = torch.split(x, [3 * self.dim, self.dim * 3 * 2], dim=-1)
        q, k, v = tuple(map(lambda x: x.unflatten(-1, (self.num_heads, -1)), qkv.chunk(3, dim=-1)))

        q = self.norm_q(q)
        k = self.norm_k(k)
        q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
        k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)
        x = attention_forward(q, k, v, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s (n d)")

        # Attention and MLP branches run in parallel and are fused on output.
        x1, x2 = mlp_x.chunk(2, dim=-1)
        x = torch.cat([x, torch.nn.functional.silu(x1) * x2], dim=-1)
        x = self.to_out(x)
        return x, (k, v)

    def forward(self, x, temb_mod_params, image_rotary_emb):
        mod_shift, mod_scale, mod_gate = temb_mod_params
        norm_x = (1 + mod_scale) * self.norm(x) + mod_shift
        attn_output, kv_cache = self.attention(x=norm_x, image_rotary_emb=image_rotary_emb)
        x = x + mod_gate * attn_output
        return x, kv_cache


class Flux2TimestepGuidanceEmbeddings(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.timestep_embedder = torch.nn.Sequential(nn.Linear(dim_in, dim_out, bias=False), nn.SiLU(), nn.Linear(dim_out, dim_out, bias=False))

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        timesteps_proj = get_timestep_embedding(timestep, self.dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))
        return timesteps_emb


class Flux2Modulation(nn.Module):
    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
        super().__init__()
        self.mod_param_sets = mod_param_sets
        self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)

    def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        # Predict `mod_param_sets` groups of (shift, scale, gate) parameters
        # from the timestep embedding.
        mod = torch.nn.functional.silu(temb)
        mod = self.linear(mod)
        mod = mod.unsqueeze(1)
        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))


class Flux2DiTVariantModel(torch.nn.Module):
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 128,
        out_channels: Optional[int] = None,
        num_layers: int = 5,
        num_single_layers: int = 20,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 7680,
        timestep_guidance_channels: int = 256,
        axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
        rope_theta: int = 2000,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # 1. Sinusoidal positional embedding for RoPE on image and text tokens
        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)

        # 2. Combined timestep + guidance embedding
        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
            dim_in=timestep_guidance_channels,
            dim_out=self.inner_dim,
        )

        # 3. Modulation (double-stream and single-stream blocks share modulation parameters, resp.)
        # Two sets of shift/scale/gate modulation parameters for the double-stream attn and FF sub-blocks
        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        # Only one set of modulation parameters, as the attn and FF sub-blocks run in parallel in the single stream
        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)

        # 4. Input projections
        self.img_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
        self.txt_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)

        # 5. Double Stream Transformer Blocks
        self.transformer_blocks = nn.ModuleList([Flux2TransformerBlock(dim=self.inner_dim, num_heads=num_attention_heads) for _ in range(num_layers)])

        # 6. Single Stream Transformer Blocks
        self.single_transformer_blocks = nn.ModuleList([Flux2SingleTransformerBlock(dim=self.inner_dim, num_heads=num_attention_heads) for _ in range(num_single_layers)])

        # 7. Output layers
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

    def prepare_static_parameters(self, img, txt):
        # The template model always runs at timestep 0.
        timestep = torch.zeros((1,), dtype=txt.dtype, device=txt.device)
        img_ids = []
        for latent_id, latent in enumerate(img):
            _, _, height, width = latent.shape
            # Offset each reference latent along the first RoPE axis so their
            # positions do not collide.
            x_ids = torch.cartesian_prod(torch.tensor([(latent_id + 1) * 10]), torch.arange(height), torch.arange(width), torch.arange(1))
            img_ids.append(x_ids)
        img_ids = torch.cat(img_ids, dim=0).to(txt.device)
        txt_ids = torch.cartesian_prod(torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(txt.shape[1])).to(txt.device)
        return timestep, img_ids, txt_ids

    def patchify(self, img):
        # Flatten each latent's spatial grid into a token sequence and
        # concatenate the sequences of all reference latents.
        img_ = []
        for latent in img:
            latent = rearrange(latent, "B C H W -> B (H W) C")
            img_.append(latent)
        img_ = torch.concat(img_, dim=1)
        return img_

    @torch.no_grad()
    def process_inputs(
        self,
        pipe, image, prompt,
        **kwargs
    ):
        images = image
        if not isinstance(images, list):
            images = [images]
        # Encode the low-resolution reference image(s) with the base VAE.
        pipe.load_models_to_device(["vae"])
        kv_cache_input_latents = [pipe.vae.encode(pipe.preprocess_image(image)) for image in images]
        # Reuse the base pipeline's prompt embedder for the template prompt.
        prompt_emb_unit = [unit for unit in pipe.units if unit.__class__.__name__ == "Flux2Unit_Qwen3PromptEmbedder"][0]
        kv_cache_prompt_emb = prompt_emb_unit.process(pipe, prompt)["prompt_embeds"]
        pipe.load_models_to_device([])
        return {
            "kv_cache_input_latents": kv_cache_input_latents,
            "kv_cache_prompt_emb": kv_cache_prompt_emb,
        }

    def forward(
        self,
        kv_cache_input_latents,
        kv_cache_prompt_emb,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
        **kwargs,
    ):
        img = kv_cache_input_latents
        txt = kv_cache_prompt_emb
        num_txt_tokens = txt.shape[1]

        # 1. Calculate timestep embedding and modulation parameters
        timestep, img_ids, txt_ids = self.prepare_static_parameters(img, txt)
        img = self.patchify(img)

        temb = self.time_guidance_embed(timestep)
        double_stream_mod_img = self.double_stream_modulation_img(temb)
        double_stream_mod_txt = self.double_stream_modulation_txt(temb)
        single_stream_mod = self.single_stream_modulation(temb)[0]

        # 2. Input projection for image (img) and conditioning text (txt)
        img = self.img_embedder(img)
        txt = self.txt_embedder(txt)

        # 3. Calculate RoPE embeddings from image and text tokens
        image_rotary_emb = self.pos_embed(img_ids)
        text_rotary_emb = self.pos_embed(txt_ids)
        concat_rotary_emb = (
            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
        )

        # 4. Double Stream Transformer Blocks
        kv_cache = {}
        for block_id, block in enumerate(self.transformer_blocks):
            txt, img, kv_cache_ = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                img=img,
                txt=txt,
                temb_mod_params_img=double_stream_mod_img,
                temb_mod_params_txt=double_stream_mod_txt,
                image_rotary_emb=concat_rotary_emb,
            )
            kv_cache[f"double_{block_id}"] = kv_cache_
        # Concatenate text and image streams for single-block inference
        img = torch.cat([txt, img], dim=1)

        # 5. Single Stream Transformer Blocks
        for block_id, block in enumerate(self.single_transformer_blocks):
            img, kv_cache_ = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=img,
                temb_mod_params=single_stream_mod,
                image_rotary_emb=concat_rotary_emb,
            )
            kv_cache[f"single_{block_id}"] = kv_cache_
        # # Remove text tokens from concatenated stream
        # img = img[:, num_txt_tokens:, ...]

        # # 6. Output layers
        # img = self.norm_out(img, temb)
        # output = self.proj_out(img)

        # This variant skips the output projection: it only exports the
        # per-block key/value caches consumed by the template pipeline.
        return {"kv_cache": kv_cache}


class TrainDataProcessor:
    def __init__(self):
        from diffsynth.core import UnifiedDataset
        self.image_operator = UnifiedDataset.default_image_operator(
            base_path="",  # If your dataset contains relative paths, please specify the root path here.
            max_pixels=1024*1024,
            height_division_factor=16,
            width_division_factor=16,
        )

    def __call__(self, image, prompt, **kwargs):
        return {
            "image": self.image_operator(image),
            "prompt": prompt,
        }


TEMPLATE_MODEL = Flux2DiTVariantModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = TrainDataProcessor
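

# Minimal smoke test (a sketch, not part of the shipped module; assumes
# DiffSynth-Studio is installed so the imports above resolve): instantiate the
# variant with its default hyperparameters and report the parameter count.
if __name__ == "__main__":
    model = Flux2DiTVariantModel()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Flux2DiTVariantModel: {n_params / 1e9:.2f}B parameters")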
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b32ea21f99d0e3c7443a513d97567a0100c12cab044d1ef6d57103115b66cc2
size 7751106808
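
The weights ship as a single LFS-tracked safetensors file (~7.75 GB). To verify a download or inspect the tensor layout without loading anything onto a GPU, the safetensors header can be read directly; a minimal sketch (the local path is a placeholder):

```python
from safetensors import safe_open

# Reads only the file header: tensor names and shapes, no weight data.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in list(f.keys())[:10]:  # first few entries
        print(name, f.get_slice(name).get_shape())
```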