Safetensors
kelseye commited on
Commit
8c6a02f
·
verified ·
1 Parent(s): 9406f44

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/image2_age_20.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/image2_age_50.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/image2_age_80.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # Templates - Age Control (FLUX.2-klein-base-4B)
5
+
6
+ This model is one of the Diffusion Templates series models open-sourced by [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). It allows direct control over the age of the person in the generated image by inputting the `age` parameter.
7
+
8
+ ## Results
9
+
10
+ > **Prompt:** A portrait of a woman with black hair, wearing a suit.
11
+
12
+ | Age = 20 | Age = 50 | Age = 80 |
13
+ |:---:|:---:|:---:|
14
+ | ![](./assets/image1_age_20.jpg) | ![](./assets/image1_age_50.jpg) | ![](./assets/image1_age_80.jpg) |
15
+
16
+ ---
17
+
18
+ > **Prompt:** A portrait of a man, autumn park background, warm evening sunlight.
19
+
20
+ | Age = 20 | Age = 50 | Age = 80 |
21
+ |:---:|:---:|:---:|
22
+ | ![](./assets/image2_age_20.jpg) | ![](./assets/image2_age_50.jpg) | ![](./assets/image2_age_80.jpg) |
23
+
24
+ ---
25
+ > **Prompt:** A fashion portrait of an elegant woman wearing a red silk dress, high fashion photography, soft lighting. A modern minimalist living room with furniture.
26
+
27
+ | Age = 20 | Age = 50 | Age = 80 |
28
+ |:---:|:---:|:---:|
29
+ | ![](./assets/image3_age_20.jpg) | ![](./assets/image3_age_50.jpg) | ![](./assets/image3_age_80.jpg) |
30
+
31
+ ## Inference Code
32
+
33
+ * Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
34
+
35
+ ```
36
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
37
+ cd DiffSynth-Studio
38
+ pip install -e .
39
+ ```
40
+
41
+ * Direct inference (requires 40GB GPU memory)
42
+
43
+ ```python
44
+ from diffsynth.diffusion.template import TemplatePipeline
45
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
46
+ import torch
47
+
48
+
49
+ pipe = Flux2ImagePipeline.from_pretrained(
50
+ torch_dtype=torch.bfloat16,
51
+ device="cuda",
52
+ model_configs=[
53
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
54
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
55
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
56
+ ],
57
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
58
+ )
59
+ template = TemplatePipeline.from_pretrained(
60
+ torch_dtype=torch.bfloat16,
61
+ device="cuda",
62
+ model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Age")],
63
+ )
64
+ image = template(
65
+ pipe,
66
+ prompt="A portrait of a woman with black hair, wearing a suit.",
67
+ seed=0, cfg_scale=4, num_inference_steps=50,
68
+ template_inputs=[{"age": 20}],
69
+ negative_template_inputs=[{"age": 45}],
70
+ )
71
+ image.save(f"image_age_20.jpg")
72
+ image = template(
73
+ pipe,
74
+ prompt="A portrait of a woman with black hair, wearing a suit.",
75
+ seed=0, cfg_scale=4, num_inference_steps=50,
76
+ template_inputs=[{"age": 50}],
77
+ negative_template_inputs=[{"age": 45}],
78
+ )
79
+ image.save(f"image_age_50.jpg")
80
+ image = template(
81
+ pipe,
82
+ prompt="A portrait of a woman with black hair, wearing a suit.",
83
+ seed=0, cfg_scale=4, num_inference_steps=50,
84
+ template_inputs=[{"age": 80}],
85
+ negative_template_inputs=[{"age": 45}],
86
+ )
87
+ image.save(f"image_age_80.jpg")
88
+ ```
89
+
90
+ * Enable lazy loading and memory management (requires 24GB GPU memory)
91
+
92
+ ```python
93
+ from diffsynth.diffusion.template import TemplatePipeline
94
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
95
+ import torch
96
+
97
+
98
+ vram_config = {
99
+ "offload_dtype": "disk",
100
+ "offload_device": "disk",
101
+ "onload_dtype": torch.float8_e4m3fn,
102
+ "onload_device": "cpu",
103
+ "preparing_dtype": torch.float8_e4m3fn,
104
+ "preparing_device": "cuda",
105
+ "computation_dtype": torch.bfloat16,
106
+ "computation_device": "cuda",
107
+ }
108
+ pipe = Flux2ImagePipeline.from_pretrained(
109
+ torch_dtype=torch.bfloat16,
110
+ device="cuda",
111
+ model_configs=[
112
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
113
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
114
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
115
+ ],
116
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
117
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
118
+ )
119
+ template = TemplatePipeline.from_pretrained(
120
+ torch_dtype=torch.bfloat16,
121
+ device="cuda",
122
+ model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Age")],
123
+ lazy_loading=True,
124
+ )
125
+ image = template(
126
+ pipe,
127
+ prompt="A portrait of a woman with black hair, wearing a suit.",
128
+ seed=0, cfg_scale=4, num_inference_steps=50,
129
+ template_inputs=[{"age": 20}],
130
+ negative_template_inputs=[{"age": 45}],
131
+ )
132
+ image.save(f"image_age_20.jpg")
133
+ image = template(
134
+ pipe,
135
+ prompt="A portrait of a woman with black hair, wearing a suit.",
136
+ seed=0, cfg_scale=4, num_inference_steps=50,
137
+ template_inputs=[{"age": 50}],
138
+ negative_template_inputs=[{"age": 45}],
139
+ )
140
+ image.save(f"image_age_50.jpg")
141
+ image = template(
142
+ pipe,
143
+ prompt="A portrait of a woman with black hair, wearing a suit.",
144
+ seed=0, cfg_scale=4, num_inference_steps=50,
145
+ template_inputs=[{"age": 80}],
146
+ negative_template_inputs=[{"age": 45}],
147
+ )
148
+ image.save(f"image_age_80.jpg")
149
+ ```
150
+
151
+ ## Training Code
152
+
153
+ After installing DiffSynth-Studio, use the following script to start training. For more information, please refer to the [DiffSynth-Studio Documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).
154
+
155
+ ```shell
156
+ modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-Age/*" --local_dir ./data/diffsynth_example_dataset
157
+
158
+ accelerate launch examples/flux2/model_training/train.py \
159
+ --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Age \
160
+ --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Age/metadata.jsonl \
161
+ --extra_inputs "template_inputs" \
162
+ --max_pixels 1048576 \
163
+ --dataset_repeat 50 \
164
+ --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
165
+ --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-Age:" \
166
+ --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
167
+ --learning_rate 1e-4 \
168
+ --num_epochs 2 \
169
+ --remove_prefix_in_ckpt "pipe.template_model." \
170
+ --output_path "./models/train/Template-KleinBase4B-Age_full" \
171
+ --trainable_models "template_model" \
172
+ --use_gradient_checkpointing \
173
+ --find_unused_parameters
174
+ ```
README_from_modelscope.md ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ frameworks:
3
+ - Pytorch
4
+ license: Apache License 2.0
5
+ tags: []
6
+ tasks:
7
+ - text-to-image-synthesis
8
+ ---
9
+
10
+ # Templates-年龄控制(FLUX.2-klein-base-4B)
11
+
12
+ 本模型是 [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 开源的 Diffusion Templates 系列模型之一。该模型能够通过直接输入 `age` 参数,控制生成图像中人物的年龄。
13
+
14
+ ## 效果展示
15
+
16
+ > **Prompt:** A portrait of a woman with black hair, wearing a suit.
17
+
18
+ | Age = 20 | Age = 50 | Age = 80 |
19
+ |:---:|:---:|:---:|
20
+ | ![](./assets/image1_age_20.jpg) | ![](./assets/image1_age_50.jpg) | ![](./assets/image1_age_80.jpg) |
21
+
22
+ ---
23
+
24
+ > **Prompt:** A portrait of a man, autumn park background, warm evening sunlight.
25
+
26
+ | Age = 20 | Age = 50 | Age = 80 |
27
+ |:---:|:---:|:---:|
28
+ | ![](./assets/image2_age_20.jpg) | ![](./assets/image2_age_50.jpg) | ![](./assets/image2_age_80.jpg) |
29
+
30
+ ---
31
+ > **Prompt:** A fashion portrait of an elegant woman wearing a red silk dress, high fashion photography, soft lighting. A modern minimalist living room with furniture.
32
+
33
+ | Age = 20 | Age = 50 | Age = 80 |
34
+ |:---:|:---:|:---:|
35
+ | ![](./assets/image3_age_20.jpg) | ![](./assets/image3_age_50.jpg) | ![](./assets/image3_age_80.jpg) |
36
+
37
+ ## 推理代码
38
+
39
+ * 安装 [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
40
+
41
+ ```
42
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
43
+ cd DiffSynth-Studio
44
+ pip install -e .
45
+ ```
46
+
47
+ * 直接推理,需 40G 显存
48
+
49
+ ```python
50
+ from diffsynth.diffusion.template import TemplatePipeline
51
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
52
+ import torch
53
+
54
+ pipe = Flux2ImagePipeline.from_pretrained(
55
+ torch_dtype=torch.bfloat16,
56
+ device="cuda",
57
+ model_configs=[
58
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
59
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
60
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
61
+ ],
62
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
63
+ )
64
+ template = TemplatePipeline.from_pretrained(
65
+ torch_dtype=torch.bfloat16,
66
+ device="cuda",
67
+ model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Age")],
68
+ )
69
+ image = template(
70
+ pipe,
71
+ prompt="A portrait of a woman with black hair, wearing a suit.",
72
+ seed=0, cfg_scale=4, num_inference_steps=50,
73
+ template_inputs=[{"age": 20}],
74
+ negative_template_inputs=[{"age": 45}],
75
+ )
76
+ image.save(f"image_age_20.jpg")
77
+ image = template(
78
+ pipe,
79
+ prompt="A portrait of a woman with black hair, wearing a suit.",
80
+ seed=0, cfg_scale=4, num_inference_steps=50,
81
+ template_inputs=[{"age": 50}],
82
+ negative_template_inputs=[{"age": 45}],
83
+ )
84
+ image.save(f"image_age_50.jpg")
85
+ image = template(
86
+ pipe,
87
+ prompt="A portrait of a woman with black hair, wearing a suit.",
88
+ seed=0, cfg_scale=4, num_inference_steps=50,
89
+ template_inputs=[{"age": 80}],
90
+ negative_template_inputs=[{"age": 45}],
91
+ )
92
+ image.save(f"image_age_80.jpg")
93
+ ```
94
+
95
+ * 开启惰性加载和显存管理,需 24G 显存
96
+
97
+ ```python
98
+ from diffsynth.diffusion.template import TemplatePipeline
99
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
100
+ import torch
101
+
102
+ vram_config = {
103
+ "offload_dtype": "disk",
104
+ "offload_device": "disk",
105
+ "onload_dtype": torch.float8_e4m3fn,
106
+ "onload_device": "cpu",
107
+ "preparing_dtype": torch.float8_e4m3fn,
108
+ "preparing_device": "cuda",
109
+ "computation_dtype": torch.bfloat16,
110
+ "computation_device": "cuda",
111
+ }
112
+ pipe = Flux2ImagePipeline.from_pretrained(
113
+ torch_dtype=torch.bfloat16,
114
+ device="cuda",
115
+ model_configs=[
116
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
117
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
118
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
119
+ ],
120
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
121
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
122
+ )
123
+ template = TemplatePipeline.from_pretrained(
124
+ torch_dtype=torch.bfloat16,
125
+ device="cuda",
126
+ model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Age")],
127
+ lazy_loading=True,
128
+ )
129
+ image = template(
130
+ pipe,
131
+ prompt="A portrait of a woman with black hair, wearing a suit.",
132
+ seed=0, cfg_scale=4, num_inference_steps=50,
133
+ template_inputs=[{"age": 20}],
134
+ negative_template_inputs=[{"age": 45}],
135
+ )
136
+ image.save(f"image_age_20.jpg")
137
+ image = template(
138
+ pipe,
139
+ prompt="A portrait of a woman with black hair, wearing a suit.",
140
+ seed=0, cfg_scale=4, num_inference_steps=50,
141
+ template_inputs=[{"age": 50}],
142
+ negative_template_inputs=[{"age": 45}],
143
+ )
144
+ image.save(f"image_age_50.jpg")
145
+ image = template(
146
+ pipe,
147
+ prompt="A portrait of a woman with black hair, wearing a suit.",
148
+ seed=0, cfg_scale=4, num_inference_steps=50,
149
+ template_inputs=[{"age": 80}],
150
+ negative_template_inputs=[{"age": 45}],
151
+ )
152
+ image.save(f"image_age_80.jpg")
153
+ ```
154
+
155
+ ## 训练代码
156
+
157
+ 安装 DiffSynth-Studio 后,使用以下脚本可开启训练,更多信息请参考 [DiffSynth-Studio 文档](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)。
158
+
159
+ ```shell
160
+ modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-Age/*" --local_dir ./data/diffsynth_example_dataset
161
+
162
+ accelerate launch examples/flux2/model_training/train.py \
163
+ --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Age \
164
+ --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Age/metadata.jsonl \
165
+ --extra_inputs "template_inputs" \
166
+ --max_pixels 1048576 \
167
+ --dataset_repeat 50 \
168
+ --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
169
+ --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-Age:" \
170
+ --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
171
+ --learning_rate 1e-4 \
172
+ --num_epochs 2 \
173
+ --remove_prefix_in_ckpt "pipe.template_model." \
174
+ --output_path "./models/train/Template-KleinBase4B-Age_full" \
175
+ --trainable_models "template_model" \
176
+ --use_gradient_checkpointing \
177
+ --find_unused_parameters
178
+ ```
assets/image1_age_20.jpg ADDED
assets/image1_age_50.jpg ADDED
assets/image1_age_80.jpg ADDED
assets/image2_age_20.jpg ADDED

Git LFS Details

  • SHA256: 6a621a605ef3d7d5f8c60853bc4bc6509f2914de2850c891301c38a8ea0f1fbc
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB
assets/image2_age_50.jpg ADDED

Git LFS Details

  • SHA256: e252570c47c0f825a0c86bffa5d6ffaf1c65b2480221abcb04d5517fac91de2d
  • Pointer size: 131 Bytes
  • Size of remote file: 121 kB
assets/image2_age_80.jpg ADDED

Git LFS Details

  • SHA256: a9e5173edf20bcbde7058729784cd0ef0715a3f8c9f40989be69cf16b62d4e14
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
assets/image3_age_20.jpg ADDED
assets/image3_age_50.jpg ADDED
assets/image3_age_80.jpg ADDED
configuration.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"framework":"Pytorch","task":"text-to-image-synthesis"}
model.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, math
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
+
6
class SingleValueEncoder(torch.nn.Module):
    """Encode one scalar condition value into a fixed-length sequence of embeddings.

    The scalar is first mapped to a sinusoidal embedding, lifted by a small MLP,
    broadcast over ``length`` positions, and offset by a learned positional table.
    """

    def __init__(self, dim_in=256, dim_out=4096, length=32):
        super().__init__()
        self.length = length
        # Two-layer MLP lifting the sinusoidal scalar embedding to dim_out.
        # NOTE: attribute names must stay fixed — they are state-dict keys for
        # the pretrained safetensors weights.
        self.prefer_value_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out),
            torch.nn.SiLU(),
            torch.nn.Linear(dim_out, dim_out),
        )
        # Learned per-position offsets so the repeated token is not uniform.
        self.positional_embedding = torch.nn.Parameter(torch.randn(self.length, dim_out))

    def get_timestep_embedding(self, timesteps, embedding_dim, max_period=10000):
        """Return a sinusoidal embedding of shape (batch, embedding_dim).

        The first half of the last dimension holds cosines, the second half sines.
        """
        half = embedding_dim // 2
        freqs = torch.exp(
            torch.arange(half, dtype=torch.float32, device=timesteps.device)
            * (-math.log(max_period) / half)
        )
        angles = timesteps.float()[:, None] * freqs[None, :]
        return torch.cat((angles.cos(), angles.sin()), dim=-1)

    def forward(self, value, dtype):
        # The raw value is scaled by 10 before embedding; 256 matches dim_in of
        # the MLP's first linear layer.
        scalar_emb = self.get_timestep_embedding(value * 10, 256).to(dtype)
        token = self.prefer_value_embedder(scalar_emb).squeeze(0)
        # Broadcast the single token across all positions, then add the offsets.
        pos = self.positional_embedding.to(dtype=token.dtype, device=token.device)
        return token.expand(self.length, -1) + pos
27
+
28
+
29
class ValueFormatModel(torch.nn.Module):
    """Template model turning a scalar ``age`` into per-block attention K/V caches.

    One :class:`SingleValueEncoder` pair (key and value) is kept per transformer
    block name; the forward pass returns a ``kv_cache`` dict keyed by block name.
    """

    def __init__(self, num_double_blocks=5, num_single_blocks=20, dim=3072, num_heads=24, length=512):
        super().__init__()
        double_names = [f"double_{i}" for i in range(num_double_blocks)]
        single_names = [f"single_{i}" for i in range(num_single_blocks)]
        self.block_names = double_names + single_names
        # NOTE: ModuleDict attribute names and keys are state-dict keys for the
        # pretrained safetensors weights — do not rename.
        self.proj_k = torch.nn.ModuleDict(
            {name: SingleValueEncoder(dim_out=dim, length=length) for name in self.block_names}
        )
        self.proj_v = torch.nn.ModuleDict(
            {name: SingleValueEncoder(dim_out=dim, length=length) for name in self.block_names}
        )
        self.num_heads = num_heads
        self.length = length

    @torch.no_grad()
    def process_inputs(self, pipe, age, **kwargs):
        """Convert the user-facing ``age`` argument into the tensor ``forward`` expects."""
        return {"value": torch.Tensor([age]).to(dtype=pipe.torch_dtype, device=pipe.device)}

    def forward(self, value, **kwargs):
        """Build the per-block key/value cache for the given condition value.

        Each entry is a pair of tensors shaped (1, length, num_heads, head_dim).
        """
        kv_cache = {}
        head_split = (1, self.length, self.num_heads, -1)
        for name in self.block_names:
            keys = self.proj_k[name](value, value.dtype).view(head_split)
            values = self.proj_v[name](value, value.dtype).view(head_split)
            kv_cache[name] = (keys, values)
        return {"kv_cache": kv_cache}
51
+
52
+
53
class DataAnnotator:
    """Training-data annotator: forwards the ``age`` field, discarding other keys."""

    def __call__(self, age, **kwargs):
        annotation = {"age": age}
        return annotation
56
+
57
+
58
# Module-level hooks — presumably discovered by name by DiffSynth-Studio's
# TemplatePipeline loader (verify against the loader implementation).
TEMPLATE_MODEL_PATH = "model.safetensors"  # weights file within this repo
TEMPLATE_MODEL = ValueFormatModel          # model class to instantiate
TEMPLATE_DATA_PROCESSOR = DataAnnotator    # dataset annotation callable
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46f30866638a48ffe5cf3d1524ec74c66d6ab9414176ec1ff371a9b51cf17ddf
3
+ size 1180292000