kelseye committed · Commit 300a347 · verified · parent 3c3aef6

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/cat_Brightness_dark.jpg filter=lfs diff=lfs merge=lfs -text
+assets/cat_Brightness_light.jpg filter=lfs diff=lfs merge=lfs -text
+assets/cat_Brightness_normal.jpg filter=lfs diff=lfs merge=lfs -text
+assets/yard_Brightness_dark.jpg filter=lfs diff=lfs merge=lfs -text
+assets/yard_Brightness_light.jpg filter=lfs diff=lfs merge=lfs -text
+assets/yard_Brightness_normal.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,176 @@
---
license: apache-2.0
---
# Templates - Brightness Adjustment (FLUX.2-klein-base-4B)

This model is part of the open-source Diffusion Templates series from [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). It is a brightness adjustment model that lets you control the brightness of generated images.

## Results

> **Prompt:** A cat is sitting on a stone.

| dark | normal | light |
|:---:|:---:|:---:|
| ![](./assets/cat_Brightness_dark.jpg) | ![](./assets/cat_Brightness_normal.jpg) | ![](./assets/cat_Brightness_light.jpg) |

---

> **Prompt:** A graceful ballerina posing on a stage.

| dark | normal | light |
|:---:|:---:|:---:|
| ![](./assets/girl_Brightness_dark.jpg) | ![](./assets/girl_Brightness_normal.jpg) | ![](./assets/girl_Brightness_light.jpg) |

---

> **Prompt:** A quiet courtyard with a small fountain in the center.

| dark | normal | light |
|:---:|:---:|:---:|
| ![](./assets/yard_Brightness_dark.jpg) | ![](./assets/yard_Brightness_normal.jpg) | ![](./assets/yard_Brightness_light.jpg) |

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference, requires 40 GB of GPU memory:

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Brightness")],
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.7}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_light.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.5}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_normal.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.3}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_dark.jpg")
```
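
The three calls above differ only in the `scale` passed through `template_inputs` (0.3 / 0.5 / 0.7 for dark / normal / light, with `negative_template_inputs` held at 0.5). A compact equivalent sweep, reusing the `pipe` and `template` objects from the block above, might look like:

```python
# Sketch: brightness sweep with the pipeline and template set up above.
for name, scale in [("dark", 0.3), ("normal", 0.5), ("light", 0.7)]:
    image = template(
        pipe,
        prompt="A cat is sitting on a stone.",
        seed=0, cfg_scale=4, num_inference_steps=50,
        template_inputs=[{"scale": scale}],
        negative_template_inputs=[{"scale": 0.5}],
    )
    image.save(f"image_Brightness_{name}.jpg")
```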

* Enable lazy loading and memory management, requires 24 GB of GPU memory:

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Brightness")],
    lazy_loading=True,
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.7}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_light.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.5}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_normal.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.3}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_dark.jpg")
```
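
Which variant to choose depends on total device memory. A quick check, as a sketch assuming a single CUDA device, could be:

```python
import torch

# torch.cuda.mem_get_info returns (free, total) memory in bytes.
total_gb = torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3)
mode = "lazy loading" if total_gb < 40 else "direct inference"
print(f"Total VRAM: {total_gb:.1f} GB -> suggested mode: {mode}")
```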

## Training Code

After installing DiffSynth-Studio, use the following script to start training. For more information, please refer to the [DiffSynth-Studio documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-Brightness/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-Brightness:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Brightness_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
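
The training script consumes `template_inputs` as an extra input; the `DataAnnotator` in `model.py` (added later in this commit) computes the brightness target for each training image as its mean pixel intensity normalized to [0, 1]. A minimal reproduction of that annotation, with `example.jpg` as a placeholder filename:

```python
import numpy as np
from PIL import Image

# Mirrors DataAnnotator in model.py: the brightness target of an image
# is its mean pixel intensity scaled to [0, 1].
scale = np.array(Image.open("example.jpg")).astype(np.float32).mean() / 255
print(f"scale = {scale:.3f}")
```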
README_from_modelscope.md ADDED
@@ -0,0 +1,179 @@
---
frameworks:
- Pytorch
license: Apache License 2.0
tags: []
tasks:
- text-to-image-synthesis
---

# Templates - Brightness Adjustment (FLUX.2-klein-base-4B)

This model is part of the open-source Diffusion Templates series from [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). It is a brightness adjustment model that lets you control the brightness of generated images.

## Results

> **Prompt:** A cat is sitting on a stone.

| dark | normal | light |
|:---:|:---:|:---:|
| ![](./assets/cat_Brightness_dark.jpg) | ![](./assets/cat_Brightness_normal.jpg) | ![](./assets/cat_Brightness_light.jpg) |

---

> **Prompt:** A graceful ballerina posing on a stage.

| dark | normal | light |
|:---:|:---:|:---:|
| ![](./assets/girl_Brightness_dark.jpg) | ![](./assets/girl_Brightness_normal.jpg) | ![](./assets/girl_Brightness_light.jpg) |

---

> **Prompt:** A quiet courtyard with a small fountain in the center.

| dark | normal | light |
|:---:|:---:|:---:|
| ![](./assets/yard_Brightness_dark.jpg) | ![](./assets/yard_Brightness_normal.jpg) | ![](./assets/yard_Brightness_light.jpg) |

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference, requires 40 GB of GPU memory:

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Brightness")],
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.7}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_light.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.5}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_normal.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.3}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_dark.jpg")
```

* Enable lazy loading and memory management, requires 24 GB of GPU memory:

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Brightness")],
    lazy_loading=True,
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.7}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_light.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.5}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_normal.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"scale": 0.3}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_Brightness_dark.jpg")
```

## Training Code

After installing DiffSynth-Studio, use the following script to start training. For more information, please refer to the [DiffSynth-Studio documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-Brightness/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-Brightness:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Brightness_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
assets/cat_Brightness_dark.jpg ADDED

Git LFS Details

  • SHA256: c506c7bad1af2057a1c31fe388ac4d4b7050a999211f65b3c12536a9c1b760a2
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
assets/cat_Brightness_light.jpg ADDED

Git LFS Details

  • SHA256: 11c3b78bddc56f1fbf8ffd0866ec54909382a4fba75acb03019af1cb0aaf5c31
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
assets/cat_Brightness_normal.jpg ADDED

Git LFS Details

  • SHA256: d2317ac22f5e4dd428b4a6930c490be342cb891e492c53e8e74e9a93945920e2
  • Pointer size: 131 Bytes
  • Size of remote file: 134 kB
assets/girl_Brightness_dark.jpg ADDED
assets/girl_Brightness_light.jpg ADDED
assets/girl_Brightness_normal.jpg ADDED
assets/yard_Brightness_dark.jpg ADDED

Git LFS Details

  • SHA256: 23aeb61287461b6f4c0d81795eee92c12e631d57fafce35bcdac1713e6fba951
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
assets/yard_Brightness_light.jpg ADDED

Git LFS Details

  • SHA256: 337285e700c426e72c4dc54593e840f6b66b14db1ae6b43cc16a3fe75b28ae39
  • Pointer size: 131 Bytes
  • Size of remote file: 184 kB
assets/yard_Brightness_normal.jpg ADDED

Git LFS Details

  • SHA256: c0aa11c968221f20f633b46b474e9b08d8b561c18a22fd4df05f619820d39ca5
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
configuration.json ADDED
@@ -0,0 +1 @@
{"framework":"Pytorch","task":"text-to-image-synthesis"}
model.py ADDED
@@ -0,0 +1,62 @@
import torch, math
from PIL import Image
import numpy as np


class SingleValueEncoder(torch.nn.Module):
    # Encodes a single scalar condition (the brightness scale) into a
    # sequence of `length` embeddings: sinusoidal embedding -> MLP ->
    # plus learned positional embeddings.
    def __init__(self, dim_in=256, dim_out=4096, length=32):
        super().__init__()
        self.length = length
        self.prefer_value_embedder = torch.nn.Sequential(torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out))
        self.positional_embedding = torch.nn.Parameter(torch.randn(self.length, dim_out))

    def get_timestep_embedding(self, timesteps, embedding_dim, max_period=10000):
        # Standard sinusoidal embedding, as used for diffusion timesteps.
        half_dim = embedding_dim // 2
        exponent = -math.log(max_period) * torch.arange(0, half_dim, dtype=torch.float32, device=timesteps.device) / half_dim
        emb = timesteps[:, None].float() * torch.exp(exponent)[None, :]
        emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
        return emb

    def forward(self, value, dtype):
        emb = self.get_timestep_embedding(value * 1000, 256).to(dtype)
        emb = self.prefer_value_embedder(emb).squeeze(0)
        base_embeddings = emb.expand(self.length, -1)
        positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
        learned_embeddings = base_embeddings + positional_embedding
        return learned_embeddings


class ValueFormatModel(torch.nn.Module):
    # Maps the scalar condition to one (key, value) sequence per named
    # double-/single-stream block, returned as a KV cache for the pipeline.
    def __init__(self, num_double_blocks=5, num_single_blocks=20, dim=3072, num_heads=24, length=512):
        super().__init__()
        self.block_names = [f"double_{i}" for i in range(num_double_blocks)] + [f"single_{i}" for i in range(num_single_blocks)]
        self.proj_k = torch.nn.ModuleDict({block_name: SingleValueEncoder(dim_out=dim, length=length) for block_name in self.block_names})
        self.proj_v = torch.nn.ModuleDict({block_name: SingleValueEncoder(dim_out=dim, length=length) for block_name in self.block_names})
        self.num_heads = num_heads
        self.length = length

    @torch.no_grad()
    def process_inputs(self, pipe, scale, **kwargs):
        return {"value": torch.Tensor([scale]).to(dtype=pipe.torch_dtype, device=pipe.device)}

    def forward(self, value, **kwargs):
        # One (k, v) pair per block, each of shape (1, length, heads, dim // heads).
        kv_cache = {}
        for block_name in self.block_names:
            k = self.proj_k[block_name](value, value.dtype)
            k = k.view(1, self.length, self.num_heads, -1)
            v = self.proj_v[block_name](value, value.dtype)
            v = v.view(1, self.length, self.num_heads, -1)
            kv_cache[block_name] = (k, v)
        return {"kv_cache": kv_cache}


class DataAnnotator:
    # Training-time annotator: the brightness target of an image is its
    # mean pixel intensity normalized to [0, 1].
    def __call__(self, image, **kwargs):
        image = Image.open(image)
        image = np.array(image)
        return {"scale": image.astype(np.float32).mean() / 255}


TEMPLATE_MODEL = ValueFormatModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator
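
As a quick sanity check of the tensor shapes, a sketch assuming `model.py` is importable and the default constructor arguments above:

```python
import torch
from model import ValueFormatModel

# Defaults: 5 double + 20 single blocks, dim=3072, 24 heads, length=512.
model = ValueFormatModel()
out = model(torch.tensor([0.7]))  # brightness scale in [0, 1]
k, v = out["kv_cache"]["double_0"]
print(k.shape, v.shape)  # torch.Size([1, 512, 24, 128]) for both
```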
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04c33313c04e6664dbfd2a1b9a2a69a5da780634916e0a6ed93992e200be6ec1
size 1180292000