YongganFu pmolchanov committed on
Commit
c6706ba
·
0 Parent(s):

Initial release of Nemotron-Labs-Diffusion-VLM-8B

Co-authored-by: pmolchanov <pmolchanov@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,41 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/demo.gif filter=lfs diff=lfs merge=lfs -text
+ assets/demo.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ assets/result_acc.png filter=lfs diff=lfs merge=lfs -text
+ assets/result_efficiency.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,124 @@
+ ---
+ library_name: transformers
+ license: other
+ license_name: nscl-v1
+ pipeline_tag: image-text-to-text
+ tags:
+ - nvidia
+ - pytorch
+ - multimodal
+ - vlm
+ - diffusion-language-model
+ ---
+
+ # Nemotron-Labs-Diffusion-VLM-8B
+
+
+ <div align="center" style="line-height: 1;">
+   <a href="https://d1qx31qr3h6wln.cloudfront.net/publications/Nemotron_Diffusion_Tech_Report_v1.pdf?VersionId=db8_EMO8B.vmU26.jr7Le9pN3MqcUDNL" target="_blank" style="margin: 2px;">
+     <img alt="Chat" src="https://img.shields.io/badge/📝Paper-Read Now!-536af5?color=76B900&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
+   </a>
+   <a href="https://huggingface.co/collections/nvidia/nemotron-labs-diffusion" target="_blank" style="margin: 2px;">
+     <img alt="Nemotron-Labs-Diffusion Model Family" src="https://img.shields.io/badge/%F0%9F%A4%97-Nemotron--Labs--Diffusion_Model_Family-76B900" style="display: inline-block; vertical-align: middle;"/>
+   </a>
+   <a href="https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-source-code-license/" style="margin: 2px;">
+     <img alt="License" src="https://img.shields.io/badge/License-NSCLv1-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
+   </a>
+ </div>
+
+
+ [![Demo](./assets/demo.gif)](./assets/demo.mp4)
+
+
+ ## Model Overview
+
+ Nemotron-Labs-Diffusion-VLM-8B is the vision-language extension of the Nemotron-Labs-Diffusion family. It pairs the same tri-mode language backbone (AR / diffusion / self-speculation, switchable by attention pattern) with a vision encoder, accepting interleaved image + text input and producing text output. The diffusion-based parallel decoding from the LM family carries over to the VLM: the language head can draft a block in parallel and verify it autoregressively against a shared KV cache, retaining the family's decoding efficiency while extending it to multimodal prompts.
+
+ <div align="center">
+   <img src="./assets/teaser.png" alt="An illustration of Tri-Mode LMs" width="500">
+ </div>
+
+
+ ## Key Design
+
+ - 8B vision-language model in the Nemotron-Labs-Diffusion family: the same tri-mode language backbone (AR, diffusion, self-speculation) plus a Pixtral-style vision encoder.
+ - Vision encoder: 24-layer, 1024-hidden, 14×14 patch, supports up to 1540×1540 images with `spatial_merge_size=2`.
+ - Language decoder weights match `nvidia/Nemotron-Labs-Diffusion-8B` (34 layers, 4096 hidden, 14336 intermediate); the model card structure and inference modes are inherited from the LM line.
+ - Diffusion-based parallel decoding works for multimodal prompts: image tokens are placed in the bidirectional context window and text generation proceeds via the same block-wise unmasking + AR verification as the LM family.
+
+
+ ## License/Terms of Use
+
+ Use of this model is governed by the **NVIDIA Source Code License (NSCLv1)**.
+
+
+ ## Environment
+
+ ```text
+ transformers>=5.0.0
+ pillow
+ requests
+ opencv-python
+ ```
+
+
+ ## Chat with Our Model
+
+ ```python
+ import sys
+ import torch
+ from huggingface_hub import snapshot_download
+ from transformers import AutoModel, AutoTokenizer
+
+ repo_name = "nvidia/Nemotron-Labs-Diffusion-VLM-8B"
+ sys.path.insert(0, snapshot_download(repo_name))
+ from image_processing import process_messages
+
+ tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
+ model = AutoModel.from_pretrained(repo_name, trust_remote_code=True).cuda().to(torch.bfloat16)
+
+ image_path = "path/to/your/image.jpg"  # local file or http(s):// URL
+ messages = [{
+     "role": "user",
+     "content": [
+         {"type": "image_url", "image_url": {"url": image_path}},
+         {"type": "text", "text": "Describe this image."},
+     ],
+ }]
+
+ batch = process_messages(tokenizer, messages, add_generation_prompt=True)
+ prompt_ids = batch["input_ids"].to("cuda")
+ pixel_values = batch["pixel_values"].to("cuda", dtype=torch.bfloat16)
+
+ out_ids, nfe = model.generate(
+     prompt_ids,
+     pixel_values=pixel_values,
+     image_sizes=batch["image_sizes"],
+     max_new_tokens=512, steps=512, block_length=32,
+     shift_logits=False, threshold=0.9,
+     eos_token_id=tokenizer.eos_token_id,
+ )
+
+ tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)
+ print(f"Model: {tokenized_out[0]}")
+ print(f"[Num Function Eval (NFE)={nfe}]")
+ ```
+
+
+ ## Ethical Considerations
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the [bias](./model_cards/bias.md), [explainability](./model_cards/explainability.md), [safety & security](./model_cards/safety.md), and [privacy](./model_cards/privacy.md) subcards.
+
+ Please report model quality, risk, security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
+
+
+ ## Citations
+
+ ```bibtex
+ @techreport{fu2026nemotronlabsdiffusion,
+   title = {Nemotron-Labs-Diffusion: A Tri-Mode Language Model Unifying Autoregressive, Diffusion, and Self-Speculation Decoding},
+   author = {Yonggan Fu and Lexington Whalen and Abhinav Garg and Chengyue Wu and Maksim Khadkevich and Nicolai Oswald and Enze Xie and Daniel Egert and Sharath Turuvekere Sreenivas and Shizhe Diao and Chenhan Yu and Ye Yu and Weijia Chen and Sajad Norouzi and Shiyi Lan and Ligeng Zhu and Jin Wang and Jindong Jiang and Morteza Mardani and Mehran Maghoumi and Song Han and Ante Jukic and Nima Tajbakhsh and Jan Kautz and Pavlo Molchanov},
+   institution = {NVIDIA},
+   year = {2026},
+   note = {Technical report}
+ }
+ ```
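
Editorial note on the README's vision-encoder bullet: the figures it quotes (14×14 patches, images up to 1540×1540, `spatial_merge_size=2`) imply a rough upper bound on visual tokens per image. The sketch below works through that arithmetic. It assumes that each `spatial_merge_size` × `spatial_merge_size` group of patches is merged into one token and that image sides are multiples of the patch size; the helper name `vision_token_count` is hypothetical and not part of the repository, and the estimate ignores any row-break or end-of-image special tokens.

```python
def vision_token_count(height, width, patch_size=14, spatial_merge_size=2):
    """Estimate visual tokens for one image (hypothetical helper, not repo code)."""
    # Patches per side, assuming the sides are multiples of the patch size.
    rows, cols = height // patch_size, width // patch_size
    # Assumption: each merge_size x merge_size group of patches yields one token.
    return (rows // spatial_merge_size) * (cols // spatial_merge_size)


# Largest image size quoted in the README: 110 x 110 patches, merged to 55 x 55.
print(vision_token_count(1540, 1540))  # 3025
```

Under these assumptions a maximum-size image contributes about three thousand tokens to the bidirectional context before any text is generated.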
assets/demo.gif ADDED

Git LFS Details

  • SHA256: 0d09264e272ac0f82dee36417f6a16511287ec1f8dee3b5dba3da222d791fd2c
  • Pointer size: 132 Bytes
  • Size of remote file: 8.25 MB
assets/demo.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:666d8785ac4af75931d9c677757c4ef9945bf114d07f1c4e2ebb7b893ac39006
+ size 9454873
assets/result_acc.png ADDED

Git LFS Details

  • SHA256: 992aa22ca9eca3d0bddbcd9f49837e2a9f377bbc0f7545563b129a50b3811448
  • Pointer size: 131 Bytes
  • Size of remote file: 405 kB
assets/result_efficiency.png ADDED

Git LFS Details

  • SHA256: 4f6161912e2aa703e0ef1bdccbb85039529b97e759d6247c33afa2a209806ede
  • Pointer size: 131 Bytes
  • Size of remote file: 801 kB
assets/teaser.png ADDED

Git LFS Details

  • SHA256: 6c94aa7b0c6cf8fb739724d0c1ce45749c76443c592eeab94d7cbb9083c6c6b1
  • Pointer size: 131 Bytes
  • Size of remote file: 581 kB
chat_template.jinja ADDED
@@ -0,0 +1,227 @@
+ {% macro render_extra_keys(json_dict, handled_keys) %}
+ {%- if json_dict is mapping %}
+ {%- for json_key in json_dict if json_key not in handled_keys %}
+ {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
+ {%- else %}
+ {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {% endmacro %}
+ {%- set enable_thinking = enable_thinking if enable_thinking is defined else True %}
+ {%- set truncate_history_thinking = truncate_history_thinking if truncate_history_thinking is defined else True %}
+
+ {%- set ns = namespace(last_user_idx = -1) %}
+ {%- set loop_messages = messages %}
+ {%- for m in loop_messages %}
+ {%- if m["role"] == "user" %}
+ {%- set ns.last_user_idx = loop.index0 %}
+ {%- endif %}
+ {%- endfor %}
+
+ {%- if messages[0]["role"] == "system" %}
+ {%- set system_message = messages[0]["content"] %}
+ {%- set loop_messages = messages[1:] %}
+ {%- else %}
+ {%- set system_message = "" %}
+ {%- set loop_messages = messages %}
+ {%- endif %}
+ {%- if not tools is defined %}
+ {%- set tools = [] %}
+ {%- endif %}
+ {# Recompute last_user_idx relative to loop_messages after handling system #}
+ {%- set ns = namespace(last_user_idx = -1) %}
+ {%- for m in loop_messages %}
+ {%- if m["role"] == "user" %}
+ {%- set ns.last_user_idx = loop.index0 %}
+ {%- endif %}
+ {%- endfor %}
+ {%- if system_message is defined %}
+ {{- "<|im_start|>system\n" + system_message }}
+ {%- else %}
+ {%- if tools is iterable and tools | length > 0 %}
+ {{- "<|im_start|>system\n" }}
+ {%- endif %}
+ {%- endif %}
+ {%- if tools is iterable and tools | length > 0 %}
+ {%- if system_message is defined and system_message | length > 0 %}
+ {{- "\n\n" }}
+ {%- endif %}
+ {{- "# Tools\n\nYou have access to the following functions:\n\n" }}
+ {{- "<tools>" }}
+ {%- for tool in tools %}
+ {%- if tool.function is defined %}
+ {%- set tool = tool.function %}
+ {%- endif %}
+ {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
+ {%- if tool.description is defined %}
+ {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
+ {%- endif %}
+ {{- '\n<parameters>' }}
+ {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
+ {%- for param_name, param_fields in tool.parameters.properties|items %}
+ {{- '\n<parameter>' }}
+ {{- '\n<name>' ~ param_name ~ '</name>' }}
+ {%- if param_fields.type is defined %}
+ {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
+ {%- endif %}
+ {%- if param_fields.description is defined %}
+ {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
+ {%- endif %}
+ {%- if param_fields.enum is defined %}
+ {{- '\n<enum>' ~ (param_fields.enum | tojson | safe) ~ '</enum>' }}
+ {%- endif %}
+ {%- set handled_keys = ['name', 'type', 'description', 'enum'] %}
+ {{- render_extra_keys(param_fields, handled_keys) }}
+ {{- '\n</parameter>' }}
+ {%- endfor %}
+ {%- endif %}
+ {% set handled_keys = ['type', 'properties', 'required'] %}
+ {{- render_extra_keys(tool.parameters, handled_keys) }}
+ {%- if tool.parameters is defined and tool.parameters.required is defined %}
+ {{- '\n<required>' ~ (tool.parameters.required | tojson | safe) ~ '</required>' }}
+ {%- endif %}
+ {{- '\n</parameters>' }}
+ {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
+ {{- render_extra_keys(tool, handled_keys) }}
+ {{- '\n</function>' }}
+ {%- endfor %}
+ {{- "\n</tools>" }}
+
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
+ {%- endif %}
+
+
+ {%- if system_message is defined %}
+ {{- '<|im_end|>\n' }}
+ {%- else %}
+ {%- if tools is iterable and tools | length > 0 %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+
+ {%- for message in loop_messages %}
+ {%- if message.role == "assistant" %}
+ {# Add reasoning content in to content field for unified processing below. #}
+ {%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
+ {%- set msg_content = message.content | default('', true) %}
+ {%- if msg_content is not string -%}
+ {%- set ns_rc = namespace(text='') -%}
+ {%- for block in msg_content -%}
+ {%- if block.type is defined and block.type == "text" -%}
+ {%- set ns_rc.text = ns_rc.text + block.text -%}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- set msg_content = ns_rc.text -%}
+ {%- endif -%}
+ {%- set content = "<think>\n" ~ message.reasoning_content ~ "\n</think>\n" ~ msg_content %}
+ {%- else %}
+ {%- set content = message.content | default('', true) %}
+ {%- if content is not string -%}
+ {%- set ns_c = namespace(text='') -%}
+ {%- for block in content -%}
+ {%- if block.type is defined and block.type == "text" -%}
+ {%- set ns_c.text = ns_c.text + block.text -%}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- set content = ns_c.text -%}
+ {%- endif -%}
+ {%- if '<think>' not in content and '</think>' not in content -%}
+ {%- set content = "<think></think>" ~ content -%}
+ {%- endif -%}
+ {%- endif %}
+ {%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
+ {# Assistant message has tool calls. #}
+ {{- '<|im_start|>assistant\n' }}
+ {%- set include_content = not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
+ {%- if content is string and content | trim | length > 0 %}
+ {%- if include_content %}
+ {{- (content | trim) ~ '\n' -}}
+ {%- else %}
+ {%- set c = (content | string) %}
+ {%- if '</think>' in c %}
+ {# Keep only content after the last closing think. Also generation prompt causes this. #}
+ {%- set c = c.split('</think>')[-1] %}
+ {%- elif '<think>' in c %}
+ {# If <think> was opened but never closed, drop the trailing think segment #}
+ {%- set c = c.split('<think>')[0] %}
+ {%- endif %}
+ {%- set c = "<think></think>" ~ c | trim %}
+ {%- if c | length > 0 %}
+ {{- c ~ '\n' -}}
+ {%- endif %}
+ {%- endif %}
+ {%- else %}
+ {{- "<think></think>" -}}
+ {%- endif %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if tool_call.function is defined %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+ {{- '<tool_call>\n<function=' ~ tool_call.name ~ '>\n' -}}
+ {%- if tool_call.arguments is defined %}
+ {%- for args_name, args_value in tool_call.arguments|items %}
+ {{- '<parameter=' ~ args_name ~ '>\n' -}}
+ {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
+ {{- args_value ~ '\n</parameter>\n' -}}
+ {%- endfor %}
+ {%- endif %}
+ {{- '</function>\n</tool_call>\n' -}}
+ {%- endfor %}
+ {{- '<|im_end|>\n' }}
+ {%- else %}
+ {# Assistant message doesn't have tool calls. #}
+ {%- if not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
+ {{- '<|im_start|>assistant\n' ~ (content | default('', true) | string | trim) ~ '<|im_end|>\n' }}
+ {%- else %}
+ {%- set c = (content | default('', true) | string) %}
+ {%- if '<think>' in c and '</think>' in c %}
+ {%- set c = "<think></think>" ~ c.split('</think>')[-1] %}
+ {%- endif %}
+ {%- set c = c | trim %}
+ {%- if c | length > 0 %}
+ {{- '<|im_start|>assistant\n' ~ c ~ '<|im_end|>\n' }}
+ {%- else %}
+ {{- '<|im_start|>assistant\n<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- endif %}
+ {%- elif message.role == "user" or message.role == "system" %}
+ {{- '<|im_start|>' + message.role + '\n' }}
+ {%- if message.content is string %}
+ {{- message.content }}
+ {%- else %}
+ {%- for block in message.content %}
+ {%- if block.type == "text" %}
+ {{- block.text }}
+ {%- elif block.type in ["image", "image_url"] %}
+ {{- '<|image_start|>' }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.previtem and loop.previtem.role != "tool" %}
+ {{- '<|im_start|>user\n' }}
+ {%- endif %}
+ {{- '<tool_response>\n' }}
+ {{- message.content }}
+ {{- '\n</tool_response>\n' }}
+ {%- if not loop.last and loop.nextitem.role != "tool" %}
+ {{- '<|im_end|>\n' }}
+ {%- elif loop.last %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endfor %}
+
+ {%- if add_generation_prompt %}
+ {%- if enable_thinking %}
+ {{- '<|im_start|>assistant\n<think>\n' }}
+ {%- else %}
+ {{- '<|im_start|>assistant\n<think></think>' }}
+ {%- endif %}
+ {%- endif %}
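
Editorial note on the template's user/system branch and generation prompt: the plain-Python sketch below mirrors how a user turn with mixed text and image blocks is flattened, and which prefix is appended when `add_generation_prompt` is set. It is a simplified illustration rather than the template itself. The function names `render_user_turn` and `generation_prompt` are hypothetical, and the sketch omits the system block, tools, and assistant-history handling.

```python
def render_user_turn(content, role="user"):
    """Plain-Python mirror of the template's user/system branch (illustrative only)."""
    parts = ["<|im_start|>" + role + "\n"]
    if isinstance(content, str):
        parts.append(content)
    else:
        for block in content:
            if block.get("type") == "text":
                parts.append(block["text"])
            elif block.get("type") in ("image", "image_url"):
                # Each image block becomes a single placeholder token.
                parts.append("<|image_start|>")
    parts.append("<|im_end|>\n")
    return "".join(parts)


def generation_prompt(enable_thinking=True):
    """Assistant prefix the template appends when add_generation_prompt is set."""
    if enable_thinking:
        return "<|im_start|>assistant\n<think>\n"
    return "<|im_start|>assistant\n<think></think>"


content = [
    {"type": "image_url", "image_url": {"url": "path/to/your/image.jpg"}},
    {"type": "text", "text": "Describe this image."},
]
print(repr(render_user_turn(content) + generation_prompt(enable_thinking=False)))
```

Blocks are emitted in the order they appear in `content`, so the image placeholder precedes the text here only because the image block comes first in the list.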
chat_utils.py ADDED
@@ -0,0 +1,272 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+
+ def add_gumbel_noise(logits, temperature):
+     '''
+     The Gumbel max is a method for sampling categorical distributions.
+     According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
+     Thus, we use float64.
+     '''
+     if temperature == 0:
+         return logits
+     logits = logits.to(torch.float64)
+     noise = torch.rand_like(logits, dtype=torch.float64)
+     gumbel_noise = (- torch.log(noise)) ** temperature
+     return logits.exp() / gumbel_noise
+
+
+ def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
+     logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+     x0 = torch.argmax(logits_with_noise, dim=-1)
+
+     if remasking == 'low_confidence':
+         # p = F.softmax(logits.to(torch.float64), dim=-1)
+         p = F.softmax(logits, dim=-1)
+         x0_p = torch.squeeze(
+             torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
+     elif remasking == 'top_p_margin':
+         # Compute probabilities
+         p = F.softmax(logits, dim=-1)  # (B, L, V)
+         # Top-2 per position
+         top2 = torch.topk(p, k=2, dim=-1).values  # (B, L, 2)
+         margin = top2[..., 0] - top2[..., 1]  # (B, L)
+
+         # Normalize margin to [0,1] over MASKED positions per row
+         plus_inf = torch.full_like(margin, float('inf'))
+         minus_inf = torch.full_like(margin, float('-inf'))
+         masked_for_min = torch.where(mask_index, margin, plus_inf)
+         masked_for_max = torch.where(mask_index, margin, minus_inf)
+         row_min = masked_for_min.amin(dim=1, keepdim=True)  # (B, 1)
+         row_max = masked_for_max.amax(dim=1, keepdim=True)  # (B, 1)
+         denom = (row_max - row_min)
+
+         # If denom==0 (all equal), set normalized=1 on masked; 0 elsewhere by default
+         normalized = torch.zeros_like(margin)
+         nonzero = denom > 0
+         normalized = torch.where(
+             mask_index & nonzero,
+             (margin - row_min) / (denom + 1e-12),
+             normalized
+         )
+         normalized = torch.where(
+             mask_index & (~nonzero),
+             torch.ones_like(normalized),
+             normalized
+         )
+         x0_p = normalized  # ∈ [0,1] on masked positions
+     elif remasking == 'random':
+         x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+     else:
+         raise NotImplementedError(remasking)
+
+     # Calculate negative entropy if requested
+     if neg_entropy:
+         # p = F.softmax(logits.to(torch.float64), dim=-1)
+         p = F.softmax(logits, dim=-1)
+         epsilon = 1e-10
+         log_probs = torch.log(p + epsilon)
+         confidence_scores = torch.sum(p * log_probs, dim=-1)  # negative entropy per position
+     else:
+         confidence_scores = x0_p
+
+     x0 = torch.where(mask_index, x0, x)
+     confidence = torch.where(mask_index, confidence_scores, -np.inf)
+
+     transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
+     if threshold is not None:
+         num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
+     # print(f'confidence: {confidence}')
+     for j in range(confidence.shape[0]):
+         _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
+         transfer_index[j, select_index] = True
+         if threshold is not None:
+             for k in range(1, num_transfer_tokens[j]):
+                 if confidence[j, select_index[k]] < threshold:
+                     transfer_index[j, select_index[k]] = False
+     return x0, transfer_index
+
+
+ def get_num_transfer_tokens(mask_index, steps: int):
+     mask_num = mask_index.sum(dim=1, keepdim=True)
+     base = mask_num // steps
+     remainder = mask_num % steps
+     num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
+     for i in range(mask_num.size(0)):
+         num_transfer_tokens[i, : int(remainder[i])] += 1
+     return num_transfer_tokens
+
+
+ @torch.no_grad()
+ def generate_with_prefix_cache_block_diff(
+     model,
+     prompt,
+     steps=128,
+     gen_length=128,
+     block_length=128,
+     temperature=0.,
+     remasking='low_confidence',
+     mask_id=126336,
+     threshold=None,
+     factor=None,
+     shift_logits=False,
+     neg_entropy=False,
+     causal_context=False,
+     pixel_values=None,
+     image_sizes=None,
+     eos_token_id=None,
+ ):
+     dream_style=shift_logits
+     x_accum = prompt.clone()
+
+     assert gen_length % block_length == 0
+     num_blocks = gen_length // block_length
+
+     assert steps % num_blocks == 0
+     steps_per_block = steps // num_blocks
+
+     nfe = 0
+
+     if causal_context:
+         model_module = model.module if hasattr(model, "module") else model
+         for layer in model_module.encoder.layers:
+             if hasattr(layer.self_attn, 'diffusion_lm'):
+                 layer.self_attn.diffusion_lm=False
+
+     # Compute KV cache for the prompt initially
+     # Pass pixel_values/image_sizes only for this first call (prompt contains image tokens)
+     output = model(prompt, use_cache=True, use_causal_mask=causal_context,
+                    pixel_values=pixel_values, image_sizes=image_sizes)
+     past_key_values = output.past_key_values
+
+     if causal_context:
+         for layer in model_module.encoder.layers:
+             if hasattr(layer.self_attn, 'diffusion_lm'):
+                 layer.self_attn.diffusion_lm=True
+
+     # For dream_style: store the "next token logit" of the context
+     next_logits_context = None
+     if dream_style:
+         next_logits_context = output.logits[:, -1:, :]  # (B, 1, V)
+
+     for num_block in range(num_blocks):
+         # Create a new block with mask tokens (no seeding)
+         mask_block = torch.ones(
+             (prompt.shape[0], block_length),
+             dtype=prompt.dtype,
+             device=prompt.device
+         ) * mask_id
+
+         # Append the block of masks
+         x_accum = torch.cat([x_accum, mask_block], dim=1)
+         current_block_start = prompt.size(1) + num_block * block_length
+         block_slice = slice(current_block_start, current_block_start + block_length)
+
+         # Build the initial mask for this block
+         mask_block_idx0 = (x_accum[:, block_slice] == mask_id)  # (B, Lb)
+
+         # Precompute the transfer schedule for this block
+         if dream_style:
+             # still denoise *all* positions (0..Lb-1), since none are seeded
+             schedule_mask = mask_block_idx0
+         else:
+             schedule_mask = mask_block_idx0
+
+         num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block)  # (B, steps)
+
+         # Denoise the current block
+         for i in range(steps_per_block):
+             mask_block_idx = (x_accum[:, block_slice] == mask_id)  # (B, Lb)
+             if mask_block_idx.sum() == 0:
+                 break
+
+             nfe += 1
+
+             # Forward only the current noisy block using cached context
+             logits_block = model(
+                 x_accum[:, block_slice],
+                 past_key_values=past_key_values,
+                 use_cache=False
+             ).logits
+
+             if dream_style:
+                 # Align logits so that each masked position has a predictor:
+                 # prepend context-next logit, then use logits_block[:-1]
+                 if block_length == 1:
+                     logits_use = next_logits_context  # (B, 1, V)
+                 else:
+                     logits_use = torch.cat(
+                         [next_logits_context, logits_block[:, :-1, :]],
+                         dim=1
+                     )  # (B, Lb, V)
+
+                 mask_use = mask_block_idx  # (B, Lb)
+                 x_use = x_accum[:, block_slice]  # (B, Lb)
+
+                 x0, transfer_idx = get_transfer_index(
+                     logits_use, temperature, remasking, mask_use, x_use,
+                     num_transfer_tokens=num_transfer_tokens[:, i],
+                     threshold=threshold, neg_entropy=neg_entropy
+                 )
+                 cur = x_accum[:, block_slice].clone()
+                 cur[transfer_idx] = x0[transfer_idx]
+                 x_accum[:, block_slice] = cur
+
+             else:
+                 # non-AR (same-position) case
+                 x0, transfer_idx = get_transfer_index(
+                     logits_block, temperature, remasking, mask_block_idx,
+                     x_accum[:, block_slice],
+                     num_transfer_tokens=num_transfer_tokens[:, i],
+                     threshold=threshold, neg_entropy=neg_entropy
+                 )
+                 cur = x_accum[:, block_slice].clone()
+                 cur[transfer_idx] = x0[transfer_idx]
+                 x_accum[:, block_slice] = cur
+
+             if eos_token_id is not None:
+                 block_tokens = x_accum[:, block_slice]  # (B, Lb)
+                 eos_mask = (block_tokens == eos_token_id)  # (B, Lb)
+                 any_eos = eos_mask.any(dim=1)  # (B,)
+                 if any_eos.any():
+                     after_eos = eos_mask.cumsum(dim=1).bool()  # (B, Lb)
+                     mask_before = (block_tokens == mask_id) & ~after_eos
+                     if (any_eos & ~mask_before.any(dim=1)).any():
+                         break
+
+         if causal_context:
+             for layer in model_module.encoder.layers:
+                 if hasattr(layer.self_attn, 'diffusion_lm'):
+                     layer.self_attn.diffusion_lm=False
+
+         # after block is fully denoised, update KV cache
+         output = model(
+             x_accum[:, block_slice],
+             past_key_values=past_key_values,
+             use_cache=True,
+             use_causal_mask=causal_context
+         )
+         past_key_values = output.past_key_values
+
+         if causal_context:
+             for layer in model_module.encoder.layers:
+                 if hasattr(layer.self_attn, 'diffusion_lm'):
+                     layer.self_attn.diffusion_lm=True
+
+         if dream_style and num_block < num_blocks - 1:
+             # refresh context-next logit for the next block
+             next_logits_context = output.logits[:, -1:, :]  # (B, 1, V)
+
+         if eos_token_id is not None:
+             gen_so_far = x_accum[:, prompt.size(1):]  # (B, gen_len_so_far)
+             is_eos = (gen_so_far == eos_token_id)  # (B, gen_len_so_far)
+             has_eos = is_eos.any(dim=1)  # (B,)
+             if has_eos.all():
+                 return x_accum, nfe
+
+         # first_eos_pos = is_eos.to(torch.int64).argmax(dim=1)  # (B,)
+         # max_eos = first_eos_pos.max().item()
+         # return x_accum[:, : prompt.size(1) + max_eos + 1], nfe
+
+     return x_accum, nfe
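
Editorial note on the two unmasking policies in `chat_utils.py`: the torch-free sketch below restates them for a single sequence. `transfer_schedule` mirrors the even split computed by `get_num_transfer_tokens`, and `select_by_threshold` mirrors the threshold branch of `get_transfer_index`, which ignores the fixed schedule, always commits the most confident masked position, and commits any further position whose confidence is at least the threshold. Both helper names are hypothetical, and the inputs are assumed to cover masked positions only.

```python
def transfer_schedule(num_masked, steps):
    """Split num_masked tokens over steps; earlier steps absorb the remainder."""
    base, remainder = divmod(num_masked, steps)
    return [base + 1 if i < remainder else base for i in range(steps)]


def select_by_threshold(confidences, threshold):
    """Indices to unmask in one step, given confidences of the masked positions."""
    if not confidences:
        return []
    # Rank positions from most to least confident.
    order = sorted(range(len(confidences)), key=lambda i: confidences[i], reverse=True)
    # The top-ranked position is always kept, so every step makes progress.
    keep = [order[0]] + [i for i in order[1:] if confidences[i] >= threshold]
    return sorted(keep)


print(transfer_schedule(32, 5))                            # [7, 7, 6, 6, 6]
print(select_by_threshold([0.95, 0.4, 0.91, 0.2], 0.9))    # [0, 2]
print(select_by_threshold([0.3, 0.5], 0.9))                # [1]
```

With a threshold, the number of forward passes per block depends on how confident the model is, which is why the generation function reports `nfe` instead of assuming `steps` evaluations.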
config.json ADDED
@@ -0,0 +1,106 @@
+ {
+   "ada_dlm_loss_ratio": null,
+   "ada_perm_ratio_global": null,
+   "ada_perm_ratio_per_block": null,
+   "adaptive_mask_rate": false,
+   "always_mask_im_end": false,
+   "ar_loss_weight": 1.0,
+   "architectures": [
+     "NemotronLabsDiffusionVLMModel"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "attn_implementation": null,
+   "auto_map": {
+     "AutoConfig": "configuration_nemotron_labs_diffusion_vlm.NemotronLabsDiffusionVLMConfig",
+     "AutoModel": "modeling_nemotron_labs_diffusion_vlm.NemotronLabsDiffusionVLMModel",
+     "AutoModelForCausalLM": "modeling_nemotron_labs_diffusion_vlm.NemotronLabsDiffusionVLMModel"
+   },
+   "block_size": 32,
+   "bos_token_id": 1,
+   "complementary_mask": true,
+   "diff_loss_weight": 0.5,
+   "dlm_arch": "encoder",
+   "dlm_loss_weight": 0.5,
+   "dlm_paradigm": "bidirectional",
+   "dlm_type": "llada",
+   "dp_varying_mask_ratio": false,
+   "dtype": "bfloat16",
+   "enforce_mask": false,
+   "eos_token_id": 11,
+   "global_loss_avg": true,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "im_end_token_id": 11,
+   "initializer_range": 0.02,
+   "intermediate_size": 14336,
+   "mask_token_id": 100,
+   "max_position_embeddings": 262144,
+   "mlp_bias": false,
+   "model_type": "nemotron_labs_diffusion_vlm",
+   "multi_sampling": null,
+   "multimodal_projector_bias": false,
+   "num_ar_layers": 0,
+   "num_attention_heads": 32,
+   "num_diffusion_layers": 0,
+   "num_hidden_layers": 34,
+   "num_key_value_heads": 8,
+   "num_skip_loss_tokens": 0,
+   "pad_token_id": 11,
+   "prefix_ratio": 0.8,
+   "projector_hidden_act": "gelu",
+   "random_length_prob": 0,
+   "rms_norm_eps": 1e-05,
+   "rope_parameters": {
+     "beta_fast": 32.0,
+     "beta_slow": 1.0,
+     "factor": 16.0,
+     "llama_4_scaling_beta": 0.1,
+     "mscale": 1.0,
+     "mscale_all_dim": 1.0,
+     "original_max_position_embeddings": 16384,
+     "rope_theta": 1000000.0,
+     "rope_type": "yarn",
+     "type": "yarn"
+   },
+   "rope_scaling": {
+     "beta_fast": 32.0,
+     "beta_slow": 1.0,
+     "factor": 16.0,
+     "llama_4_scaling_beta": 0.1,
+     "mscale": 1.0,
+     "mscale_all_dim": 1.0,
+     "original_max_position_embeddings": 16384,
+     "rope_theta": 1000000.0,
+     "rope_type": "yarn",
+     "type": "yarn"
+   },
+   "rope_theta": 1000000.0,
+   "sliding_window": null,
+   "spatial_merge_size": 2,
+   "tie_word_embeddings": false,
+   "tok_mask_half_life_ratio": null,
+   "transformers_version": "4.57.1",
+   "use_cache": false,
+   "vision_config": {
+     "attention_dropout": 0.0,
+     "head_dim": 64,
89
+ "hidden_act": "silu",
90
+ "hidden_size": 1024,
91
+ "image_size": 1540,
92
+ "initializer_range": 0.02,
93
+ "intermediate_size": 4096,
94
+ "model_type": "pixtral",
95
+ "num_attention_heads": 16,
96
+ "num_channels": 3,
97
+ "num_hidden_layers": 24,
98
+ "patch_size": 14,
99
+ "rope_parameters": {
100
+ "rope_theta": 10000.0,
101
+ "rope_type": "default"
102
+ }
103
+ },
104
+ "vision_feature_layer": -1,
105
+ "vocab_size": 131073
106
+ }
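Several values in this `config.json` are tied together by simple arithmetic. The sketch below only copies numbers from the file and derives the relationships; the flattened key names `rope_factor`, `vision_image_size`, and `vision_patch_size` are introduced here for illustration and are not keys in the file.

```python
# Values copied from config.json; key names flattened for this sketch.
cfg = {
    "hidden_size": 4096,
    "head_dim": 128,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "max_position_embeddings": 262144,
    "rope_factor": 16.0,                        # rope_parameters.factor
    "original_max_position_embeddings": 16384,  # rope_parameters.original_max_position_embeddings
    "vision_image_size": 1540,                  # vision_config.image_size
    "vision_patch_size": 14,                    # vision_config.patch_size
    "spatial_merge_size": 2,
}

# Width of the query projection and of each key/value projection.
q_width = cfg["num_attention_heads"] * cfg["head_dim"]
kv_width = cfg["num_key_value_heads"] * cfg["head_dim"]

# Grouped Query Attention: number of query heads sharing one KV head.
gqa_group = cfg["num_attention_heads"] // cfg["num_key_value_heads"]

# YaRN scaling: original context length times the scaling factor.
extended_ctx = int(cfg["rope_factor"] * cfg["original_max_position_embeddings"])

# Vision tower: patches per side, then merged tokens per side.
patches_per_side = cfg["vision_image_size"] // cfg["vision_patch_size"]
merged_per_side = patches_per_side // cfg["spatial_merge_size"]
```

Under these values the query width equals `hidden_size`, and the YaRN-extended context equals `max_position_embeddings`.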
configuration_nemotron_labs_diffusion_vlm.py ADDED
@@ -0,0 +1,259 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Nemotron-Labs Diffusion VLM model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class NemotronLabsDiffusionVLMConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
28
+ It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*, defaults to 131072):
35
+ Vocabulary size of the Ministral model.
36
+ hidden_size (`int`, *optional*, defaults to 4096):
37
+ Dimension of the hidden representations.
38
+ intermediate_size (`int`, *optional*, defaults to 14336):
39
+ Dimension of the MLP representations.
40
+ num_hidden_layers (`int`, *optional*, defaults to 34):
41
+ Number of hidden layers in the Transformer decoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 32):
43
+ Number of attention heads for each attention layer.
44
+ num_key_value_heads (`int`, *optional*, defaults to 8):
45
+ Number of key_value heads for Grouped Query Attention.
46
+ head_dim (`int`, *optional*, defaults to 128):
47
+ The attention head dimension.
48
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
49
+ The non-linear activation function.
50
+ max_position_embeddings (`int`, *optional*, defaults to 262144):
51
+ The maximum sequence length.
52
+ initializer_range (`float`, *optional*, defaults to 0.02):
53
+ The standard deviation of the truncated_normal_initializer.
54
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
55
+ The epsilon used by the rms normalization layers.
56
+ use_cache (`bool`, *optional*, defaults to `True`):
57
+ Whether or not the model should return the last key/values attentions.
58
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
59
+ Whether the model's input and output word embeddings should be tied.
60
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
61
+ The base period of the RoPE embeddings.
62
+ rope_parameters (`Dict`, *optional*):
63
+ Dictionary containing the scaling configuration for the RoPE embeddings.
64
+ Default uses YaRN scaling with factor=16, original_max_position_embeddings=16384.
65
+ attention_bias (`bool`, *optional*, defaults to `False`):
66
+ Whether to use a bias in the query, key, value and output projection layers.
67
+ attention_dropout (`float`, *optional*, defaults to 0.0):
68
+ The dropout ratio for the attention probabilities.
69
+ mlp_bias (`bool`, *optional*, defaults to `False`):
70
+ Whether to use a bias in up_proj, down_proj and gate_proj layers.
71
+ sliding_window (`int`, *optional*, defaults to None):
72
+ Sliding window attention size.
73
+ mask_token_id (`int`, *optional*, defaults to -1):
74
+ Token ID for masking in diffusion.
75
+ dlm_type (`str`, *optional*, defaults to 'llada'):
76
+ Type of diffusion language model ('llada', 'dream').
77
+ random_length_prob (`float`, *optional*):
78
+ Probability of using random lengths during training.
79
+ num_ar_layers (`int`, *optional*, defaults to 0):
80
+ Number of autoregressive layers.
81
+ num_diffusion_layers (`int`, *optional*, defaults to 0):
82
+ Number of diffusion layers.
83
+ diff_loss_weight (`float`, *optional*, defaults to 1):
84
+ Weight for diffusion loss.
85
+ enforce_mask (`bool`, *optional*, defaults to False):
86
+ Whether to enforce masking.
87
+ prefix_ratio (`float`, *optional*, defaults to 0.8):
88
+ Ratio for prefix in prefix_bidirectional mode.
89
+ dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
90
+ Paradigm for diffusion ('bidirectional', 'autoregressive', 'prefix_bidirectional', 'efficient_block_diff', 'block_diff', 'sbd_block_diff').
91
+ dlm_arch (`str`, *optional*, defaults to 'encoder'):
92
+ Architecture type ('encoder', 'encoder_decoder').
93
+ block_size (`int`, *optional*, defaults to 32):
94
+ Block size for block diffusion paradigms.
95
+ tok_mask_half_life_ratio (`float`, *optional*):
96
+ Half-life ratio for token masking.
97
+ adaptive_mask_rate (`bool`, *optional*, defaults to False):
98
+ Whether to use adaptive mask rate.
99
+ multi_sampling (`int`, *optional*):
100
+ Number of samples for multi-sampling.
101
+ num_skip_loss_tokens (`int`, *optional*, defaults to 0):
102
+ Number of tokens to skip in loss calculation.
103
+ dlm_loss_weight (`float`, *optional*):
104
+ Weight for diffusion LM loss.
105
+ ar_loss_weight (`float`, *optional*, defaults to 1.0):
106
+ Weight for the autoregressive loss in the sbd_block_diff paradigm. Set to 10000 to use only the AR loss.
107
+ global_loss_avg (`bool`, *optional*, defaults to False):
108
+ Whether to use global loss average.
109
+ dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
110
+ Whether to use varying mask ratio for each DP rank during sampling.
111
+ ada_perm_ratio_per_block (`float`, *optional*):
112
+ Adaptive permutation ratio for each block.
113
+ ada_perm_ratio_global (`float`, *optional*):
114
+ Global adaptive permutation ratio.
115
+ complementary_mask (`bool`, *optional*, defaults to False):
116
+ Whether to use complementary masking (mask + inverse mask).
117
+ always_mask_im_end (`bool`, *optional*, defaults to False):
118
+ Whether to always mask im_end tokens.
119
+ im_end_token_id (`int`, *optional*, defaults to 11):
120
+ Token ID of the im_end token, used when `always_mask_im_end` is enabled.
121
+ """
122
+
123
+ model_type = "nemotron_labs_diffusion_vlm"
124
+ keys_to_ignore_at_inference = ["past_key_values"]
125
+
126
+ # Default tensor parallel plan for base model `Ministral`
127
+ base_model_tp_plan = {
128
+ "layers.*.self_attn.q_proj": "colwise",
129
+ "layers.*.self_attn.k_proj": "colwise",
130
+ "layers.*.self_attn.v_proj": "colwise",
131
+ "layers.*.self_attn.o_proj": "rowwise",
132
+ "layers.*.mlp.gate_proj": "colwise",
133
+ "layers.*.mlp.up_proj": "colwise",
134
+ "layers.*.mlp.down_proj": "rowwise",
135
+ }
136
+ base_model_pp_plan = {
137
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
138
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
139
+ "norm": (["hidden_states"], ["hidden_states"]),
140
+ }
141
+
142
+ def __init__(
143
+ self,
144
+ vocab_size=131072,
145
+ hidden_size=4096,
146
+ intermediate_size=14336,
147
+ num_hidden_layers=34,
148
+ num_attention_heads=32,
149
+ num_key_value_heads=8,
150
+ head_dim=128,
151
+ hidden_act="silu",
152
+ max_position_embeddings=262144,
153
+ initializer_range=0.02,
154
+ rms_norm_eps=1e-05,
155
+ use_cache=True,
156
+ pad_token_id=None,
157
+ bos_token_id=1,
158
+ eos_token_id=2,
159
+ tie_word_embeddings=False,
160
+ rope_theta=1000000.0,
161
+ rope_parameters=None,
162
+ rope_scaling=None,
163
+ attention_bias=False,
164
+ attention_dropout=0.0,
165
+ mlp_bias=False,
166
+ sliding_window=None,
167
+ attn_implementation="sdpa",
168
+ mask_token_id=-1,
169
+ dlm_type='llada',
170
+ random_length_prob=None,
171
+ num_ar_layers=0,
172
+ num_diffusion_layers=0,
173
+ diff_loss_weight=1,
174
+ enforce_mask=False,
175
+ prefix_ratio=0.8,
176
+ dlm_paradigm='bidirectional',
177
+ dlm_arch='encoder',
178
+ block_size=32,
179
+ tok_mask_half_life_ratio=None,
180
+ adaptive_mask_rate=False,
181
+ multi_sampling=None,
182
+ num_skip_loss_tokens=0,
183
+ dlm_loss_weight=None,
184
+ ar_loss_weight=1.0,
185
+ global_loss_avg=False,
186
+ dp_varying_mask_ratio=False,
187
+ ada_perm_ratio_per_block=None,
188
+ ada_perm_ratio_global=None,
189
+ ada_dlm_loss_ratio=None,
190
+ complementary_mask=False,
191
+ always_mask_im_end=False,
192
+ im_end_token_id=11,
193
+ **kwargs,
194
+ ):
195
+ self.vocab_size = vocab_size
196
+ self.max_position_embeddings = max_position_embeddings
197
+ self.hidden_size = hidden_size
198
+ self.intermediate_size = intermediate_size
199
+ self.num_hidden_layers = num_hidden_layers
200
+ self.num_attention_heads = num_attention_heads
201
+
202
+ # for backward compatibility
203
+ if num_key_value_heads is None:
204
+ num_key_value_heads = num_attention_heads
205
+
206
+ self.num_key_value_heads = num_key_value_heads
207
+ self.head_dim = head_dim
208
+ self.hidden_act = hidden_act
209
+ self.initializer_range = initializer_range
210
+ self.rms_norm_eps = rms_norm_eps
211
+ self.use_cache = use_cache
212
+ self.rope_theta = rope_theta
213
+ self.rope_parameters = rope_parameters
214
+ self.rope_scaling = rope_scaling
215
+ self.attention_bias = attention_bias
216
+ self.attention_dropout = attention_dropout
217
+ self.mlp_bias = mlp_bias
218
+ self.sliding_window = sliding_window
219
+
220
+ rope_config_validation(self)
221
+
222
+ self.attn_implementation = attn_implementation
223
+
224
+ self.mask_token_id = mask_token_id
225
+ self.dlm_type = dlm_type
226
+ self.random_length_prob = random_length_prob
227
+ self.num_ar_layers = num_ar_layers
228
+ self.num_diffusion_layers = num_diffusion_layers
229
+ self.diff_loss_weight = diff_loss_weight
230
+ self.enforce_mask = enforce_mask
231
+ self.prefix_ratio = prefix_ratio
232
+ self.dlm_paradigm = dlm_paradigm
233
+ self.dlm_arch = dlm_arch
234
+ self.block_size = block_size
235
+ self.tok_mask_half_life_ratio = tok_mask_half_life_ratio
236
+ self.adaptive_mask_rate = adaptive_mask_rate
237
+ self.multi_sampling = multi_sampling
238
+ self.num_skip_loss_tokens = num_skip_loss_tokens
239
+ self.dlm_loss_weight = dlm_loss_weight
240
+ self.ar_loss_weight = ar_loss_weight
241
+ self.global_loss_avg = global_loss_avg
242
+ self.dp_varying_mask_ratio = dp_varying_mask_ratio
243
+ self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
244
+ self.ada_perm_ratio_global = ada_perm_ratio_global
245
+ self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
246
+ self.complementary_mask = complementary_mask
247
+ self.always_mask_im_end = always_mask_im_end
248
+ self.im_end_token_id = im_end_token_id
249
+ super().__init__(
250
+ pad_token_id=pad_token_id,
251
+ bos_token_id=bos_token_id,
252
+ eos_token_id=eos_token_id,
253
+ tie_word_embeddings=tie_word_embeddings,
254
+ **kwargs,
255
+ )
256
+
257
+
258
+ __all__ = ["NemotronLabsDiffusionVLMConfig"]
259
+
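The constructor defaults in this class differ from the released `config.json` in several places, which matters when instantiating the config without the checkpoint's file. The sketch below lists a subset of those pairs, copied by hand from the two files in this commit; it does not import the class, and the dictionary names are illustrative.

```python
# Constructor defaults from NemotronLabsDiffusionVLMConfig.__init__ (subset).
class_defaults = {
    "vocab_size": 131072,
    "use_cache": True,
    "pad_token_id": None,
    "eos_token_id": 2,
    "attn_implementation": "sdpa",
    "mask_token_id": -1,
    "random_length_prob": None,
    "diff_loss_weight": 1,
    "dlm_loss_weight": None,
    "global_loss_avg": False,
    "complementary_mask": False,
}

# Matching entries from the released config.json.
released = {
    "vocab_size": 131073,
    "use_cache": False,
    "pad_token_id": 11,
    "eos_token_id": 11,
    "attn_implementation": None,
    "mask_token_id": 100,
    "random_length_prob": 0,
    "diff_loss_weight": 0.5,
    "dlm_loss_weight": 0.5,
    "global_loss_avg": True,
    "complementary_mask": True,
}

# Map each differing key to (class default, released value).
overrides = {
    key: (class_defaults[key], released[key])
    for key in class_defaults
    if class_defaults[key] != released[key]
}

for key, (default, value) in overrides.items():
    print(f"{key}: default={default!r} -> released={value!r}")
```

Every key in this subset is overridden, including `mask_token_id`, whose default of -1 is not a valid token ID.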
generation_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": [
5
+ 11,
6
+ 2
7
+ ],
8
+ "pad_token_id": 11,
9
+ "transformers_version": "4.57.1",
10
+ "use_cache": false
11
+ }
image_processing.py ADDED
@@ -0,0 +1,296 @@
1
+ """
2
+ Image processing utilities for Nemotron-Diffusion-Exp-Ministral-8B-Instruct (final-template).
3
+
4
+ Implements image token expansion and pixel value preprocessing,
5
+ faithfully ported from mistral_common.tokens.tokenizers.image.ImageEncoder
6
+ to ensure identical image sizing and token counts.
7
+
8
+ Special token mapping (final-template version):
9
+ <|image_start|> (id=18) = [IMG_START] image start marker
10
+ <|image_pad|> (id=19) = [IMG] image pad token (one per merged patch)
11
+ <|image_break|> (id=20) = [IMG_BREAK] image row break
12
+ <|image_end|> (id=21) = [IMG_END] image end marker
13
+
14
+ After expansion, each image placeholder becomes:
15
+ [IMG_START] ([IMG]*W [IMG_BREAK]) * (H-1) [IMG]*W [IMG_END]
16
+
17
+ where W = width_tokens, H = height_tokens (computed via ceiling division
18
+ on the image dims after any downscaling to the max image size, matching mistral_common exactly).
19
+ """
20
+
21
+ import os
22
+ from io import BytesIO
23
+ from typing import Any, Dict, List, Tuple, Union
24
+
25
+ import cv2
26
+ import numpy as np
27
+ import requests
28
+ import torch
29
+ from PIL import Image
30
+
31
+
32
+ # ── Token strings (must match tokenizer_config.json) ──────────────────────────
33
+ IMG_START_TOKEN = "<|image_start|>" # id = 18
34
+ IMG_PAD_TOKEN = "<|image_pad|>" # id = 19
35
+ IMG_BREAK_TOKEN = "<|image_break|>" # id = 20
36
+ IMG_END_TOKEN = "<|image_end|>" # id = 21
37
+
38
+ # ── Token IDs ─────────────────────────────────────────────────────────────────
39
+ IMG_START_ID = 18
40
+ IMG_PAD_ID = 19
41
+ IMG_BREAK_ID = 20
42
+ IMG_END_ID = 21
43
+
44
+ # ── Default config (from config.json / processor_config.json) ─────────────────
45
+ DEFAULT_PATCH_SIZE = 14
46
+ DEFAULT_SPATIAL_MERGE_SIZE = 2
47
+ DEFAULT_MAX_IMAGE_SIZE = 1400 # longest edge
48
+ # Allow override via environment variable (e.g. from run_all_benchmarks.sh)
49
+ _env_max = os.environ.get("DEFAULT_MAX_IMAGE_SIZE")
50
+ if _env_max is not None and str(_env_max).strip():
51
+ try:
52
+ DEFAULT_MAX_IMAGE_SIZE = int(_env_max)
53
+ except ValueError:
54
+ pass
55
+
56
+ DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) # RGB
57
+ DATASET_STD = (0.26862954, 0.26130258, 0.27577711) # RGB
58
+
59
+
60
+ # ══════════════════════════════════════════════════════════════════════════════
61
+ # Image loading (mirrors mistral_common.tokens.tokenizers.image)
62
+ # ══════════════════════════════════════════════════════════════════════════════
63
+
64
+ def _convert_to_rgb(image: Image.Image) -> Image.Image:
65
+ """Convert PIL image to RGB; transparent backgrounds become white."""
66
+ if image.mode == "RGB":
67
+ return image
68
+ if image.mode != "RGBA":
69
+ image = image.convert("RGBA")
70
+ white_bg = Image.new("RGBA", image.size, "WHITE")
71
+ white_bg.paste(image, (0, 0), image)
72
+ return white_bg.convert("RGB")
73
+
74
+
75
+ def load_image(source: Union[str, Image.Image]) -> Image.Image:
76
+ """Load an image from a URL, local file path, or PIL Image."""
77
+ if isinstance(source, Image.Image):
78
+ return source
79
+ if source.startswith(("http://", "https://")):
80
+ resp = requests.get(source, stream=True, timeout=30)
81
+ resp.raise_for_status()
82
+ return Image.open(BytesIO(resp.content))
83
+ return Image.open(source)
84
+
85
+
86
+ # ══════════════════════════════════════════════════════════════════════════════
87
+ # Core logic — ported from mistral_common ImageEncoder
88
+ # ══════════════════════════════════════════════════════════════════════════════
89
+
90
+ def _image_to_num_tokens(
91
+ img: Image.Image,
92
+ image_patch_size: int = DEFAULT_PATCH_SIZE,
93
+ max_image_size: int = DEFAULT_MAX_IMAGE_SIZE,
94
+ spatial_merge_size: int = DEFAULT_SPATIAL_MERGE_SIZE,
95
+ ) -> Tuple[int, int]:
96
+ """
97
+ Compute (width_tokens, height_tokens) for a given image — identical to
98
+ ``mistral_common.tokens.tokenizers.image.ImageEncoder._image_to_num_tokens``.
99
+ """
100
+ w, h = img.size # PIL: (W, H)
101
+ ratio = max(h / max_image_size, w / max_image_size)
102
+ if ratio > 1:
103
+ w = round(w / ratio)
104
+ h = round(h / ratio)
105
+
106
+ width_tokens = (w - 1) // (image_patch_size * spatial_merge_size) + 1
107
+ height_tokens = (h - 1) // (image_patch_size * spatial_merge_size) + 1
108
+ return width_tokens, height_tokens
109
+
110
+
111
+ def transform_image(
112
+ image: Image.Image,
113
+ new_size: Tuple[int, int],
114
+ mean: Tuple[float, ...] = DATASET_MEAN,
115
+ std: Tuple[float, ...] = DATASET_STD,
116
+ ) -> np.ndarray:
117
+ """
118
+ Resize + normalise — identical to
119
+ ``mistral_common.tokens.tokenizers.image.transform_image``.
120
+
121
+ Args:
122
+ image: PIL Image (any mode).
123
+ new_size: Target (W, H) — cv2 convention.
124
+
125
+ Returns:
126
+ np.ndarray of shape (C, H, W), float32, normalised.
127
+ """
128
+ np_image = cv2.resize(
129
+ np.array(_convert_to_rgb(image), dtype=np.float32),
130
+ new_size,
131
+ interpolation=cv2.INTER_CUBIC,
132
+ )
133
+ np_image = np_image / 255.0
134
+ np_image = (np_image - np.array(mean, dtype=np.float32)) / np.array(std, dtype=np.float32)
135
+ return np_image.transpose(2, 0, 1)
136
+
137
+
138
+ def encode_image(
139
+ image: Image.Image,
140
+ image_patch_size: int = DEFAULT_PATCH_SIZE,
141
+ max_image_size: int = DEFAULT_MAX_IMAGE_SIZE,
142
+ spatial_merge_size: int = DEFAULT_SPATIAL_MERGE_SIZE,
143
+ ) -> Tuple[int, int, np.ndarray]:
144
+ """
145
+ Compute token dimensions **and** preprocessed pixel array for one image.
146
+
147
+ Returns:
148
+ (width_tokens, height_tokens, pixel_array)
149
+ where pixel_array has shape (C, H, W).
150
+ """
151
+ w_tok, h_tok = _image_to_num_tokens(
152
+ image, image_patch_size, max_image_size, spatial_merge_size,
153
+ )
154
+ assert w_tok > 0 and h_tok > 0
155
+
156
+ new_w = w_tok * image_patch_size * spatial_merge_size
157
+ new_h = h_tok * image_patch_size * spatial_merge_size
158
+ processed = transform_image(image, (new_w, new_h)) # cv2: (W, H)
159
+
160
+ return w_tok, h_tok, processed
161
+
162
+
163
+ # ══════════════════════════════════════════════════════════════════════════════
164
+ # Token string expansion
165
+ # ══════════════════════════════════════════════════════════════════════════════
166
+
167
+ def build_image_token_str(w_tokens: int, h_tokens: int) -> str:
168
+ """
169
+ Build the expanded image-token string for one image.
170
+
171
+ Pattern:
172
+ [IMG_START]
173
+ ([IMG]*W [IMG_BREAK]) * (H-1)
174
+ [IMG]*W [IMG_END]
175
+ """
176
+ row = IMG_PAD_TOKEN * w_tokens + IMG_BREAK_TOKEN
177
+ body = row * h_tokens
178
+ body = body[: -len(IMG_BREAK_TOKEN)] + IMG_END_TOKEN
179
+
180
+ return IMG_START_TOKEN + body
181
+
182
+
183
+ # ══════════════════════════════════════════════════════════════════════════════
184
+ # Extract image sources from OpenAI-style messages
185
+ # ══════════════════════════════════════════════════════════════════════════════
186
+
187
+ def _extract_image_sources(messages: List[Dict[str, Any]]) -> List[str]:
188
+ """Walk through OpenAI-style messages and collect image URLs / paths."""
189
+ sources: List[str] = []
190
+ for msg in messages:
191
+ content = msg.get("content", "")
192
+ if not isinstance(content, list):
193
+ continue
194
+ for block in content:
195
+ btype = block.get("type")
196
+ if btype == "image_url":
197
+ url_obj = block.get("image_url", {})
198
+ sources.append(url_obj.get("url", ""))
199
+ elif btype == "image":
200
+ for key in ("url", "path", "image"):
201
+ if key in block:
202
+ sources.append(block[key])
203
+ break
204
+ return sources
205
+
206
+
207
+ # ══════════════════════════════════════════════════════════════════════════════
208
+ # Public API
209
+ # ══════════════════════════════════════════════════════════════════════════════
210
+
211
+ def process_messages(
212
+ tokenizer,
213
+ messages: List[Dict[str, Any]],
214
+ *,
215
+ patch_size: int = DEFAULT_PATCH_SIZE,
216
+ spatial_merge_size: int = DEFAULT_SPATIAL_MERGE_SIZE,
217
+ max_image_size: int = DEFAULT_MAX_IMAGE_SIZE,
218
+ return_tensors: str = "pt",
219
+ add_generation_prompt: bool = False,
220
+ enable_thinking: bool = True,
221
+ ) -> Dict[str, Any]:
222
+ """
223
+ Process chat messages with optional images — drop-in replacement for
224
+ ``MistralCommonBackend.apply_chat_template(return_dict=True)``.
225
+
226
+ Steps:
227
+ 1. Render Jinja chat template → prompt with ``<|image_start|>`` placeholders.
228
+ 2. For each image:
229
+ a. Load image.
230
+ b. Compute token dims via ceiling division (matching mistral_common).
231
+ c. Resize to token-aligned dimensions with cv2 INTER_CUBIC.
232
+ d. Normalise pixels.
233
+ e. Replace the next ``<|image_start|>`` placeholder with the expanded
234
+ token sequence.
235
+ 3. Tokenize the expanded prompt.
236
+ 4. Return dict with ``input_ids`` (and ``pixel_values`` / ``image_sizes``
237
+ if images are present).
238
+
239
+ Args:
240
+ enable_thinking: When True (default), the generation prompt opens a
241
+ ``<think>`` block for chain-of-thought reasoning. When False,
242
+ an empty ``<think></think>`` is emitted so the model skips
243
+ the thinking phase.
244
+
245
+ Returns:
246
+ dict with keys:
247
+ input_ids : LongTensor (1, seq_len)
248
+ pixel_values : FloatTensor (N, 3, H, W) – only when images present
249
+ image_sizes : list of (H, W) tuples – only when images present
250
+ """
251
+ # ── 1. Extract image sources ──────────────────────────────────────────
252
+ image_sources = _extract_image_sources(messages)
253
+
254
+ # ── 2. Render chat template (produces <|image_start|> placeholders) ──
255
+ prompt: str = tokenizer.apply_chat_template(
256
+ messages,
257
+ tokenize=False,
258
+ add_generation_prompt=add_generation_prompt,
259
+ enable_thinking=enable_thinking,
260
+ )
261
+
262
+ # ── 3. Expand each placeholder & preprocess images ────────────────────
263
+ pixel_list: List[np.ndarray] = []
264
+ image_sizes: List[Tuple[int, int]] = []
265
+
266
+ for src in image_sources:
267
+ pil_img = load_image(src)
268
+
269
+ w_tok, h_tok, pixels = encode_image(
270
+ pil_img, patch_size, max_image_size, spatial_merge_size,
271
+ )
272
+
273
+ expanded = build_image_token_str(w_tok, h_tok)
274
+ prompt = prompt.replace(IMG_START_TOKEN, expanded, 1)
275
+
276
+ pixel_list.append(pixels)
277
+ final_h = h_tok * patch_size * spatial_merge_size
278
+ final_w = w_tok * patch_size * spatial_merge_size
279
+ image_sizes.append((final_h, final_w))
280
+
281
+ # ── 4. Tokenize ──────────────────────────────────────────────────────
282
+ if return_tensors == "pt":
283
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
284
+ else:
285
+ input_ids = tokenizer(prompt).input_ids
286
+
287
+ result: Dict[str, Any] = {"input_ids": input_ids}
288
+
289
+ if pixel_list:
290
+ if return_tensors == "pt":
291
+ result["pixel_values"] = torch.from_numpy(np.stack(pixel_list))
292
+ else:
293
+ result["pixel_values"] = np.stack(pixel_list)
294
+ result["image_sizes"] = image_sizes
295
+
296
+ return result
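The token-count and placeholder-expansion logic in `image_processing.py` can be exercised without PIL, cv2, or a tokenizer. The following is a minimal sketch, assuming plain `(width, height)` pairs in place of loaded images; `num_tokens`, `token_str`, and `expand_all` are hypothetical names that condense `_image_to_num_tokens`, `build_image_token_str`, and the expansion loop in `process_messages`. One deliberate difference: each expanded sequence itself begins with `<|image_start|>`, so a `str.replace(..., 1)` loop appears to re-match the first image's marker when a prompt holds several images. The sketch therefore advances a search offset past each expansion instead.

```python
from typing import List, Tuple

IMG_START, IMG_PAD, IMG_BREAK, IMG_END = (
    "<|image_start|>", "<|image_pad|>", "<|image_break|>", "<|image_end|>",
)


def num_tokens(w: int, h: int, patch: int = 14, merge: int = 2,
               max_size: int = 1400) -> Tuple[int, int]:
    """Ceiling-divide the (possibly downscaled) size by patch * merge."""
    ratio = max(w / max_size, h / max_size)
    if ratio > 1:
        w, h = round(w / ratio), round(h / ratio)
    cell = patch * merge
    return (w - 1) // cell + 1, (h - 1) // cell + 1


def token_str(w_tok: int, h_tok: int) -> str:
    """IMG_START, then H rows of W pads joined by breaks, then IMG_END."""
    rows = [IMG_PAD * w_tok] * h_tok
    return IMG_START + IMG_BREAK.join(rows) + IMG_END


def expand_all(prompt: str, sizes: List[Tuple[int, int]]) -> str:
    """Expand one placeholder per (w, h), searching after the previous expansion."""
    pos = 0
    for w, h in sizes:
        idx = prompt.index(IMG_START, pos)  # raises ValueError if a placeholder is missing
        expanded = token_str(*num_tokens(w, h))
        prompt = prompt[:idx] + expanded + prompt[idx + len(IMG_START):]
        pos = idx + len(expanded)           # skip past the markers just inserted
    return prompt


w_tok, h_tok = num_tokens(2800, 1400)  # downscaled to 1400 x 700 before dividing by 28
out = expand_all(f"a{IMG_START}b{IMG_START}c", [(28, 28), (56, 28)])
```

With a patch size of 14 and a merge size of 2, each token covers a 28 x 28 pixel cell, so an image expands to `W * H` pad tokens, `H - 1` break tokens, and the two start and end markers.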
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6bdc8ac3f1baef94a6a6fe6c5290031495e45ab18a7884f210dd8c157b33582
3
+ size 4984302088
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e4c3f9cac913ee79685bb55f4eee4c7220ae03ad50bfe83131401a0df023706
3
+ size 4999802904
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:562159c86be4c9cdde7250b5b78f4636746fd1e82c9e2e10310f74d5696c3503
3
+ size 4915916376
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7323f1fdacbf54bee85a2b4b074f88767c8d51754eb12495ce48f3f152048276
3
+ size 2936115968
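The four LFS pointer sizes can be cross-checked against the `metadata` block of `model.safetensors.index.json`. The sketch below copies the numbers from this commit. Attributing the small difference between on-disk size and `total_size` to per-shard file headers is an assumption, not something the commit states.

```python
# Sizes in bytes from the four LFS pointer files.
shard_sizes = [4984302088, 4999802904, 4915916376, 2936115968]

# Metadata from model.safetensors.index.json.
total_parameters = 8918034432
total_size = 17836068864

# Storage cost per parameter implied by the index metadata.
bytes_per_param = total_size / total_parameters

# Sum of the shard files and the excess over the tensor payload.
on_disk = sum(shard_sizes)
overhead = on_disk - total_size

print(f"{bytes_per_param:.1f} bytes/param, {on_disk} bytes on disk, {overhead} bytes overhead")
```

Two bytes per parameter is consistent with the `"dtype": "bfloat16"` entry in `config.json`.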
model.safetensors.index.json ADDED
@@ -0,0 +1,539 @@
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 8918034432,
4
+ "total_size": 17836068864
5
+ },
6
+ "weight_map": {
7
+ "diffusion_head.weight": "model-00004-of-00004.safetensors",
8
+ "encoder.embed_tokens.weight": "model-00001-of-00004.safetensors",
9
+ "encoder.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
10
+ "encoder.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
11
+ "encoder.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
12
+ "encoder.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
13
+ "encoder.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
14
+ "encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
15
+ "encoder.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
16
+ "encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
17
+ "encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
18
+ "encoder.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
19
+ "encoder.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
20
+ "encoder.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
21
+ "encoder.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
22
+ "encoder.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
23
+ "encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
24
+ "encoder.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
25
+ "encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
26
+ "encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
27
+ "encoder.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
28
+ "encoder.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
29
+ "encoder.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
30
+ "encoder.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
31
+ "encoder.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
32
+ "encoder.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
33
+ "encoder.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
34
+ "encoder.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
35
+ "encoder.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
36
+ "encoder.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "encoder.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
38
+ "encoder.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
39
+ "encoder.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
40
+ "encoder.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
41
+ "encoder.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
42
+ "encoder.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
43
+ "encoder.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
44
+ "encoder.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
45
+ "encoder.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
46
+ "encoder.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
47
+ "encoder.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
48
+ "encoder.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
49
+ "encoder.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
50
+ "encoder.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
51
+ "encoder.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
52
+ "encoder.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+ "encoder.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.30.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.30.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.30.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.30.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.30.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.31.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.31.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.31.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.31.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.32.input_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.32.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.32.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.32.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.32.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.32.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.32.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.32.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.32.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.33.input_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.33.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.33.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.33.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.33.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.33.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.33.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.33.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.33.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
+ "encoder.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.7.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.7.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.7.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.7.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+ "encoder.multi_modal_projector.linear_1.weight": "model-00001-of-00004.safetensors",
+ "encoder.multi_modal_projector.linear_2.weight": "model-00001-of-00004.safetensors",
+ "encoder.multi_modal_projector.norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.multi_modal_projector.patch_merger.merging_layer.weight": "model-00001-of-00004.safetensors",
+ "encoder.norm.weight": "model-00004-of-00004.safetensors",
+ "encoder.vision_tower.ln_pre.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.patch_conv.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.0.attention.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.0.attention.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.0.attention.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.0.attention.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.0.attention_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.0.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.0.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.0.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.0.ffn_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.1.attention.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.1.attention.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.1.attention.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.1.attention.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.1.attention_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.1.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.1.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.1.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.1.ffn_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.10.attention.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.10.attention.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.10.attention.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.10.attention.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.10.attention_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.10.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.10.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.10.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.10.ffn_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.11.attention.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.11.attention.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.11.attention.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.11.attention.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.11.attention_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.11.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.11.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.11.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.11.ffn_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.12.attention.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.12.attention.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.12.attention.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.12.attention.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.12.attention_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.12.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.12.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.12.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.12.ffn_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.13.attention.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.13.attention.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.13.attention.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.13.attention.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.13.attention_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.13.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.13.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.13.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.13.ffn_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.14.attention.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.14.attention.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.14.attention.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.14.attention.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.14.attention_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.14.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.14.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.14.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.14.ffn_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.15.attention.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.15.attention.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.15.attention.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.15.attention.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.15.attention_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.15.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.15.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.15.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.15.ffn_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.16.attention.k_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.16.attention.o_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.16.attention.q_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.16.attention.v_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.16.attention_norm.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.16.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
+ "encoder.vision_tower.transformer.layers.16.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
401
+ "encoder.vision_tower.transformer.layers.16.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
402
+ "encoder.vision_tower.transformer.layers.16.ffn_norm.weight": "model-00001-of-00004.safetensors",
403
+ "encoder.vision_tower.transformer.layers.17.attention.k_proj.weight": "model-00001-of-00004.safetensors",
404
+ "encoder.vision_tower.transformer.layers.17.attention.o_proj.weight": "model-00001-of-00004.safetensors",
405
+ "encoder.vision_tower.transformer.layers.17.attention.q_proj.weight": "model-00001-of-00004.safetensors",
406
+ "encoder.vision_tower.transformer.layers.17.attention.v_proj.weight": "model-00001-of-00004.safetensors",
407
+ "encoder.vision_tower.transformer.layers.17.attention_norm.weight": "model-00001-of-00004.safetensors",
408
+ "encoder.vision_tower.transformer.layers.17.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
409
+ "encoder.vision_tower.transformer.layers.17.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
410
+ "encoder.vision_tower.transformer.layers.17.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
411
+ "encoder.vision_tower.transformer.layers.17.ffn_norm.weight": "model-00001-of-00004.safetensors",
412
+ "encoder.vision_tower.transformer.layers.18.attention.k_proj.weight": "model-00001-of-00004.safetensors",
413
+ "encoder.vision_tower.transformer.layers.18.attention.o_proj.weight": "model-00001-of-00004.safetensors",
414
+ "encoder.vision_tower.transformer.layers.18.attention.q_proj.weight": "model-00001-of-00004.safetensors",
415
+ "encoder.vision_tower.transformer.layers.18.attention.v_proj.weight": "model-00001-of-00004.safetensors",
416
+ "encoder.vision_tower.transformer.layers.18.attention_norm.weight": "model-00001-of-00004.safetensors",
417
+ "encoder.vision_tower.transformer.layers.18.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
418
+ "encoder.vision_tower.transformer.layers.18.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
419
+ "encoder.vision_tower.transformer.layers.18.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
420
+ "encoder.vision_tower.transformer.layers.18.ffn_norm.weight": "model-00001-of-00004.safetensors",
421
+ "encoder.vision_tower.transformer.layers.19.attention.k_proj.weight": "model-00001-of-00004.safetensors",
422
+ "encoder.vision_tower.transformer.layers.19.attention.o_proj.weight": "model-00001-of-00004.safetensors",
423
+ "encoder.vision_tower.transformer.layers.19.attention.q_proj.weight": "model-00001-of-00004.safetensors",
424
+ "encoder.vision_tower.transformer.layers.19.attention.v_proj.weight": "model-00001-of-00004.safetensors",
425
+ "encoder.vision_tower.transformer.layers.19.attention_norm.weight": "model-00001-of-00004.safetensors",
426
+ "encoder.vision_tower.transformer.layers.19.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
427
+ "encoder.vision_tower.transformer.layers.19.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
428
+ "encoder.vision_tower.transformer.layers.19.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
429
+ "encoder.vision_tower.transformer.layers.19.ffn_norm.weight": "model-00001-of-00004.safetensors",
430
+ "encoder.vision_tower.transformer.layers.2.attention.k_proj.weight": "model-00001-of-00004.safetensors",
431
+ "encoder.vision_tower.transformer.layers.2.attention.o_proj.weight": "model-00001-of-00004.safetensors",
432
+ "encoder.vision_tower.transformer.layers.2.attention.q_proj.weight": "model-00001-of-00004.safetensors",
433
+ "encoder.vision_tower.transformer.layers.2.attention.v_proj.weight": "model-00001-of-00004.safetensors",
434
+ "encoder.vision_tower.transformer.layers.2.attention_norm.weight": "model-00001-of-00004.safetensors",
435
+ "encoder.vision_tower.transformer.layers.2.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
436
+ "encoder.vision_tower.transformer.layers.2.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
437
+ "encoder.vision_tower.transformer.layers.2.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
438
+ "encoder.vision_tower.transformer.layers.2.ffn_norm.weight": "model-00001-of-00004.safetensors",
439
+ "encoder.vision_tower.transformer.layers.20.attention.k_proj.weight": "model-00001-of-00004.safetensors",
440
+ "encoder.vision_tower.transformer.layers.20.attention.o_proj.weight": "model-00001-of-00004.safetensors",
441
+ "encoder.vision_tower.transformer.layers.20.attention.q_proj.weight": "model-00001-of-00004.safetensors",
442
+ "encoder.vision_tower.transformer.layers.20.attention.v_proj.weight": "model-00001-of-00004.safetensors",
443
+ "encoder.vision_tower.transformer.layers.20.attention_norm.weight": "model-00001-of-00004.safetensors",
444
+ "encoder.vision_tower.transformer.layers.20.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
445
+ "encoder.vision_tower.transformer.layers.20.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
446
+ "encoder.vision_tower.transformer.layers.20.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
447
+ "encoder.vision_tower.transformer.layers.20.ffn_norm.weight": "model-00001-of-00004.safetensors",
448
+ "encoder.vision_tower.transformer.layers.21.attention.k_proj.weight": "model-00001-of-00004.safetensors",
449
+ "encoder.vision_tower.transformer.layers.21.attention.o_proj.weight": "model-00001-of-00004.safetensors",
450
+ "encoder.vision_tower.transformer.layers.21.attention.q_proj.weight": "model-00001-of-00004.safetensors",
451
+ "encoder.vision_tower.transformer.layers.21.attention.v_proj.weight": "model-00001-of-00004.safetensors",
452
+ "encoder.vision_tower.transformer.layers.21.attention_norm.weight": "model-00001-of-00004.safetensors",
453
+ "encoder.vision_tower.transformer.layers.21.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
454
+ "encoder.vision_tower.transformer.layers.21.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
455
+ "encoder.vision_tower.transformer.layers.21.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
456
+ "encoder.vision_tower.transformer.layers.21.ffn_norm.weight": "model-00001-of-00004.safetensors",
457
+ "encoder.vision_tower.transformer.layers.22.attention.k_proj.weight": "model-00001-of-00004.safetensors",
458
+ "encoder.vision_tower.transformer.layers.22.attention.o_proj.weight": "model-00001-of-00004.safetensors",
459
+ "encoder.vision_tower.transformer.layers.22.attention.q_proj.weight": "model-00001-of-00004.safetensors",
460
+ "encoder.vision_tower.transformer.layers.22.attention.v_proj.weight": "model-00001-of-00004.safetensors",
461
+ "encoder.vision_tower.transformer.layers.22.attention_norm.weight": "model-00001-of-00004.safetensors",
462
+ "encoder.vision_tower.transformer.layers.22.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
463
+ "encoder.vision_tower.transformer.layers.22.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
464
+ "encoder.vision_tower.transformer.layers.22.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
465
+ "encoder.vision_tower.transformer.layers.22.ffn_norm.weight": "model-00001-of-00004.safetensors",
466
+ "encoder.vision_tower.transformer.layers.23.attention.k_proj.weight": "model-00001-of-00004.safetensors",
467
+ "encoder.vision_tower.transformer.layers.23.attention.o_proj.weight": "model-00001-of-00004.safetensors",
468
+ "encoder.vision_tower.transformer.layers.23.attention.q_proj.weight": "model-00001-of-00004.safetensors",
469
+ "encoder.vision_tower.transformer.layers.23.attention.v_proj.weight": "model-00001-of-00004.safetensors",
470
+ "encoder.vision_tower.transformer.layers.23.attention_norm.weight": "model-00001-of-00004.safetensors",
471
+ "encoder.vision_tower.transformer.layers.23.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
472
+ "encoder.vision_tower.transformer.layers.23.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
473
+ "encoder.vision_tower.transformer.layers.23.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
474
+ "encoder.vision_tower.transformer.layers.23.ffn_norm.weight": "model-00001-of-00004.safetensors",
475
+ "encoder.vision_tower.transformer.layers.3.attention.k_proj.weight": "model-00001-of-00004.safetensors",
476
+ "encoder.vision_tower.transformer.layers.3.attention.o_proj.weight": "model-00001-of-00004.safetensors",
477
+ "encoder.vision_tower.transformer.layers.3.attention.q_proj.weight": "model-00001-of-00004.safetensors",
478
+ "encoder.vision_tower.transformer.layers.3.attention.v_proj.weight": "model-00001-of-00004.safetensors",
479
+ "encoder.vision_tower.transformer.layers.3.attention_norm.weight": "model-00001-of-00004.safetensors",
480
+ "encoder.vision_tower.transformer.layers.3.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
481
+ "encoder.vision_tower.transformer.layers.3.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
482
+ "encoder.vision_tower.transformer.layers.3.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
483
+ "encoder.vision_tower.transformer.layers.3.ffn_norm.weight": "model-00001-of-00004.safetensors",
484
+ "encoder.vision_tower.transformer.layers.4.attention.k_proj.weight": "model-00001-of-00004.safetensors",
485
+ "encoder.vision_tower.transformer.layers.4.attention.o_proj.weight": "model-00001-of-00004.safetensors",
486
+ "encoder.vision_tower.transformer.layers.4.attention.q_proj.weight": "model-00001-of-00004.safetensors",
487
+ "encoder.vision_tower.transformer.layers.4.attention.v_proj.weight": "model-00001-of-00004.safetensors",
488
+ "encoder.vision_tower.transformer.layers.4.attention_norm.weight": "model-00001-of-00004.safetensors",
489
+ "encoder.vision_tower.transformer.layers.4.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
490
+ "encoder.vision_tower.transformer.layers.4.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
491
+ "encoder.vision_tower.transformer.layers.4.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
492
+ "encoder.vision_tower.transformer.layers.4.ffn_norm.weight": "model-00001-of-00004.safetensors",
493
+ "encoder.vision_tower.transformer.layers.5.attention.k_proj.weight": "model-00001-of-00004.safetensors",
494
+ "encoder.vision_tower.transformer.layers.5.attention.o_proj.weight": "model-00001-of-00004.safetensors",
495
+ "encoder.vision_tower.transformer.layers.5.attention.q_proj.weight": "model-00001-of-00004.safetensors",
496
+ "encoder.vision_tower.transformer.layers.5.attention.v_proj.weight": "model-00001-of-00004.safetensors",
497
+ "encoder.vision_tower.transformer.layers.5.attention_norm.weight": "model-00001-of-00004.safetensors",
498
+ "encoder.vision_tower.transformer.layers.5.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
499
+ "encoder.vision_tower.transformer.layers.5.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
500
+ "encoder.vision_tower.transformer.layers.5.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
501
+ "encoder.vision_tower.transformer.layers.5.ffn_norm.weight": "model-00001-of-00004.safetensors",
502
+ "encoder.vision_tower.transformer.layers.6.attention.k_proj.weight": "model-00001-of-00004.safetensors",
503
+ "encoder.vision_tower.transformer.layers.6.attention.o_proj.weight": "model-00001-of-00004.safetensors",
504
+ "encoder.vision_tower.transformer.layers.6.attention.q_proj.weight": "model-00001-of-00004.safetensors",
505
+ "encoder.vision_tower.transformer.layers.6.attention.v_proj.weight": "model-00001-of-00004.safetensors",
506
+ "encoder.vision_tower.transformer.layers.6.attention_norm.weight": "model-00001-of-00004.safetensors",
507
+ "encoder.vision_tower.transformer.layers.6.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
508
+ "encoder.vision_tower.transformer.layers.6.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
509
+ "encoder.vision_tower.transformer.layers.6.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
510
+ "encoder.vision_tower.transformer.layers.6.ffn_norm.weight": "model-00001-of-00004.safetensors",
511
+ "encoder.vision_tower.transformer.layers.7.attention.k_proj.weight": "model-00001-of-00004.safetensors",
512
+ "encoder.vision_tower.transformer.layers.7.attention.o_proj.weight": "model-00001-of-00004.safetensors",
513
+ "encoder.vision_tower.transformer.layers.7.attention.q_proj.weight": "model-00001-of-00004.safetensors",
514
+ "encoder.vision_tower.transformer.layers.7.attention.v_proj.weight": "model-00001-of-00004.safetensors",
515
+ "encoder.vision_tower.transformer.layers.7.attention_norm.weight": "model-00001-of-00004.safetensors",
516
+ "encoder.vision_tower.transformer.layers.7.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
517
+ "encoder.vision_tower.transformer.layers.7.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
518
+ "encoder.vision_tower.transformer.layers.7.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
519
+ "encoder.vision_tower.transformer.layers.7.ffn_norm.weight": "model-00001-of-00004.safetensors",
520
+ "encoder.vision_tower.transformer.layers.8.attention.k_proj.weight": "model-00001-of-00004.safetensors",
521
+ "encoder.vision_tower.transformer.layers.8.attention.o_proj.weight": "model-00001-of-00004.safetensors",
522
+ "encoder.vision_tower.transformer.layers.8.attention.q_proj.weight": "model-00001-of-00004.safetensors",
523
+ "encoder.vision_tower.transformer.layers.8.attention.v_proj.weight": "model-00001-of-00004.safetensors",
524
+ "encoder.vision_tower.transformer.layers.8.attention_norm.weight": "model-00001-of-00004.safetensors",
525
+ "encoder.vision_tower.transformer.layers.8.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
526
+ "encoder.vision_tower.transformer.layers.8.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
527
+ "encoder.vision_tower.transformer.layers.8.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
528
+ "encoder.vision_tower.transformer.layers.8.ffn_norm.weight": "model-00001-of-00004.safetensors",
529
+ "encoder.vision_tower.transformer.layers.9.attention.k_proj.weight": "model-00001-of-00004.safetensors",
530
+ "encoder.vision_tower.transformer.layers.9.attention.o_proj.weight": "model-00001-of-00004.safetensors",
531
+ "encoder.vision_tower.transformer.layers.9.attention.q_proj.weight": "model-00001-of-00004.safetensors",
532
+ "encoder.vision_tower.transformer.layers.9.attention.v_proj.weight": "model-00001-of-00004.safetensors",
533
+ "encoder.vision_tower.transformer.layers.9.attention_norm.weight": "model-00001-of-00004.safetensors",
534
+ "encoder.vision_tower.transformer.layers.9.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
535
+ "encoder.vision_tower.transformer.layers.9.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
536
+ "encoder.vision_tower.transformer.layers.9.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
537
+ "encoder.vision_tower.transformer.layers.9.ffn_norm.weight": "model-00001-of-00004.safetensors"
538
+ }
539
+ }
model_cards/bias.md ADDED
@@ -0,0 +1,4 @@
+ Field | Response
+ :---------------------------------------------------------------------------------------------------|:---------------
+ Participation considerations from adversely impacted groups [protected classes](https://www.senate.ca.gov/content/protected-classes) in model design and testing: | [None]
+ Measures taken to mitigate against unwanted bias: | [None]
model_cards/explainability.md ADDED
@@ -0,0 +1,13 @@
+ Field | Response
+ :------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------
+ Intended Task/Domain: | Text generation
+ Model Type: | Transformer
+ Intended Users: | Generative AI creators working with conversational AI models.
+ Output: | Text (Responds to posed question, Stateful - remembers previous answers)
+ Describe how the model works: | Text input is encoded into tokens and passed into a transformer-based language model, which returns a text response.
+ Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable
+ Technical Limitations & Mitigation: | The model cannot perform long-horizon reasoning and tool calling.
+ Verified to have met prescribed NVIDIA quality standards: | Yes
+ Performance Metrics: | Accuracy, Latency, Throughput
+ Potential Known Risks: | In some instances, the model may think too long and struggle to derive final answers. The model's output can generate all forms of text, including what may be considered toxic, offensive, or indecent.
+ Licensing: | nvidia-open-model-license.
model_cards/privacy.md ADDED
@@ -0,0 +1,11 @@
+ Field | Response
+ :----------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------
+ Generatable or reverse engineerable personal data? | [No]
+ Personal data used to create this model? | [No]
+ Was consent obtained for any personal data used? | [Not Applicable]
+ How often is dataset reviewed? | [During dataset creation, model training, evaluation, and the prerelease phase.]
+ Was data from user interactions with the AI model (e.g. user input and prompts) used to train the model? | [Yes]
+ Is there provenance for all datasets used in training? | Yes
+ Does data labeling (annotation, metadata) comply with privacy laws? | Yes
+ Is data compliant with data subject requests for data correction or removal, if such a request was made? | Not Applicable.
+ Applicable Privacy Policy | https://www.nvidia.com/en-us/about-nvidia/privacy-policy/
model_cards/safety.md ADDED
@@ -0,0 +1,6 @@
+ Field | Response
+ :---------------------------------------------------|:----------------------------------
+ Model Application Field(s): | [Media & Entertainment].
+ Describe the life critical impact (if present). | Not Applicable
+ Model and dataset restrictions: | The principle of least privilege (PoLP) is applied, limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints are adhered to.
+ Use Case Restrictions: | Abide by nvidia-open-model-license.
modeling_ministral.py ADDED
@@ -0,0 +1,629 @@
+ from collections.abc import Callable
+ from typing import Optional, Union
+
+ import torch
+ from torch import nn
+
+ from transformers.utils.generic import check_model_inputs
+
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.generation import GenerationMixin
+ # from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+ from transformers.integrations import use_kernel_forward_from_hub
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+ from transformers.modeling_layers import (
+     GenericForQuestionAnswering,
+     GenericForSequenceClassification,
+     GenericForTokenClassification,
+     GradientCheckpointingLayer,
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+ from transformers.models.pixtral.modeling_pixtral import PixtralVisionModel
+ from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig
+ # from transformers.utils.generic import maybe_autocast
+ from .configuration_nemotron_labs_diffusion_vlm import NemotronLabsDiffusionVLMConfig
+
+ #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
+
+
+ class Ministral3PatchMerger(nn.Module):
+     """
+     Learned merging of spatial_merge_size ** 2 patches
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         hidden_size = config.vision_config.hidden_size
+         self.spatial_merge_size = config.spatial_merge_size
+         self.patch_size = self.config.vision_config.patch_size
+         self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
+
+     def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
+         image_sizes = [
+             (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
+         ]
+
+         tokens_per_image = [h * w for h, w in image_sizes]
+         d = image_features.shape[-1]
+
+         permuted_tensor = []
+         for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
+             # Reshape image_tokens into a 2D grid
+             h, w = image_sizes[image_index]
+             image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
+             grid = torch.nn.functional.unfold(
+                 image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
+             )
+             grid = grid.view(d * self.spatial_merge_size**2, -1).t()
+             permuted_tensor.append(grid)
+
+         image_features = torch.cat(permuted_tensor, dim=0)
+         image_features = self.merging_layer(image_features)
+         return image_features
+
+
+
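To make the shape bookkeeping in `Ministral3PatchMerger.forward` concrete, the following pure-Python sketch computes how many tokens one image contributes before and after merging. The helper name `merged_token_count` and the example sizes (patch size 14, merge size 2) are illustrative assumptions, not values read from this repository's config. For simplicity the sketch rejects patch grids that do not divide evenly by the merge size.

```python
def merged_token_count(image_size, patch_size, spatial_merge_size):
    """Return (tokens_before, tokens_after) for one image.

    Mirrors the bookkeeping in the forward pass: the image is cut into a
    grid of patches, then each spatial_merge_size x spatial_merge_size
    block of patches is folded into a single token.
    """
    height, width = image_size
    grid_h, grid_w = height // patch_size, width // patch_size
    if grid_h % spatial_merge_size or grid_w % spatial_merge_size:
        raise ValueError("patch grid must be divisible by spatial_merge_size")
    tokens_before = grid_h * grid_w
    tokens_after = (grid_h // spatial_merge_size) * (grid_w // spatial_merge_size)
    return tokens_before, tokens_after


# Hypothetical sizes, chosen only for illustration.
before, after = merged_token_count((448, 448), patch_size=14, spatial_merge_size=2)
print(before, after)
```

Each merged token concatenates `spatial_merge_size ** 2` patch vectors, which is why the input width of `merging_layer` is `hidden_size * spatial_merge_size**2`.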
+ class Ministral3MultiModalProjector(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.norm = Ministral3RMSNorm(config.vision_config.hidden_size, eps=config.rms_norm_eps)
+         self.patch_merger = Ministral3PatchMerger(config)
+         # We have hidden_size * the number of vision feature layers
+         self.num_feature_layers = (
+             1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
+         )
+         self.linear_1 = nn.Linear(
+             config.vision_config.hidden_size * self.num_feature_layers,
+             config.hidden_size,
+             bias=config.multimodal_projector_bias,
+         )
+         self.act = ACT2FN[config.projector_hidden_act]
+         self.linear_2 = nn.Linear(
+             config.hidden_size, config.hidden_size, bias=config.multimodal_projector_bias
+         )
+
+     def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
+         image_features = self.norm(image_features)
+         image_features = self.patch_merger(image_features, image_sizes)
+         hidden_states = self.linear_1(image_features)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.linear_2(hidden_states)
+         return hidden_states
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+ # @use_kernel_func_from_hub("rotary_pos_emb")
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
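As a rough illustration of what `rotate_half` and `apply_rotary_pos_emb` compute, here is a pure-Python sketch on a single head vector. It assumes the cosine and sine tables repeat the same angles across both halves of the head dimension; the rotary embedding module that builds them is not shown in this section, so treat that layout as an assumption. The names `apply_rope` and `vector_norm` and the angles are hypothetical.

```python
import math


def rotate_half(x):
    """Pure-Python analogue of rotate_half for a flat list of even length."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def apply_rope(x, angles):
    """Rotate one head vector; angles holds one angle per (i, i + half) pair."""
    cos = [math.cos(a) for a in angles] * 2  # same angles for both halves
    sin = [math.sin(a) for a in angles] * 2
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rotated, sin)]


def vector_norm(v):
    """Euclidean length of a list of floats."""
    return math.sqrt(sum(t * t for t in v))


q = [1.0, 2.0, 3.0, 4.0]
q_rot = apply_rope(q, angles=[0.5, 0.1])  # hypothetical angles

print(rotate_half(q))
print(abs(vector_norm(q) - vector_norm(q_rot)) < 1e-9)
```

Each pair `(x[i], x[i + half])` is rotated by its own angle, so the vector's length is unchanged and only the relative phase between positions carries information.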
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
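The head ordering produced by `repeat_kv` can be sketched without tensors. The snippet below is a minimal pure-Python analogue in which each key/value head is a placeholder string; the helper name `repeat_kv_heads` and the head labels are hypothetical.

```python
def repeat_kv_heads(kv_heads, n_rep):
    """Repeat each key/value head n_rep times, keeping the copies adjacent.

    kv_heads is a list with one entry per key/value head. The result has
    len(kv_heads) * n_rep entries, in the order obtained by adding a new
    axis after the head axis, expanding it, and flattening it back.
    """
    return [head for head in kv_heads for _ in range(n_rep)]


# Two hypothetical KV heads shared by four query heads (n_rep = 2).
print(repeat_kv_heads(["kv0", "kv1"], n_rep=2))  # ['kv0', 'kv0', 'kv1', 'kv1']
```

Copies of the same head stay next to each other, so query heads `0..n_rep-1` all attend against the first key/value head, the next `n_rep` query heads against the second, and so on.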
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
+     scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
+     return scaling.unsqueeze(-1)
+
+
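The position-dependent query scale in `_get_llama_4_attn_scale` is easier to read as a scalar formula, `1 + beta * log(1 + floor(pos / max_pos))`. The sketch below evaluates it for a few positions. The function name `attn_scale`, the value of `beta`, and the window length are hypothetical and are not read from the model config.

```python
import math


def attn_scale(position, beta, max_position_embeddings):
    """Scalar form of the scaling: 1 + beta * log(1 + floor(pos / max_pos))."""
    return 1.0 + beta * math.log(1 + position // max_position_embeddings)


# beta and the window length below are hypothetical illustration values.
for pos in (0, 8191, 8192, 16384):
    print(pos, round(attn_scale(pos, beta=0.1, max_position_embeddings=8192), 4))
```

Every position inside the first `max_position_embeddings` tokens gets a scale of exactly 1, so queries are only amplified once the sequence runs past the original context length, and then only logarithmically.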
+ # @use_kernelized_func(apply_rotary_pos_emb)
+ class Ministral3Attention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, config: NemotronLabsDiffusionVLMConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = config.attention_dropout
+         self.is_causal = True
+         self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+         self.diffusion_lm = config.diffusion_lm
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+         past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = False,
+         **kwargs: Unpack[FlashAttentionKwargs],
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+         query_states = query_states * _get_llama_4_attn_scale(
+             cache_position,
+             self.config.rope_parameters.get("llama_4_scaling_beta"),
+             self.config.rope_parameters.get("original_max_position_embeddings"),
+         ).to(query_states.dtype)
+
+         if past_key_values is not None:
+             if use_cache:
+                 # sin and cos are specific to RoPE models; cache_position needed for the static cache
+                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+                 key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+             else: ## if use_cache == False, do not update cache
+                 old_k, old_v = past_key_values.layers[self.layer_idx].keys, past_key_values.layers[self.layer_idx].values
+                 key_states = torch.cat([old_k, key_states], dim=-2)
+                 value_states = torch.cat([old_v, value_states], dim=-2)
+
+         attention_interface: Callable = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+         if self.diffusion_lm:
+             attn_output, attn_weights = attention_interface(
+                 self,
+                 query_states,
+                 key_states,
+                 value_states,
+                 None,
+                 dropout=0.0 if not self.training else self.attention_dropout,
+                 scaling=self.scaling,
+                 is_causal=False,
+                 **kwargs,
+             )
+
+         else:
+             attn_output, attn_weights = attention_interface(
+                 self,
+                 query_states,
+                 key_states,
+                 value_states,
+                 attention_mask,
+                 dropout=0.0 if not self.training else self.attention_dropout,
+                 scaling=self.scaling,
+                 sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
+                 **kwargs,
+             )
+
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+
+
269
+ class Ministral3MLP(nn.Module):
270
+ def __init__(self, config):
271
+ super().__init__()
272
+ self.config = config
273
+ self.hidden_size = config.hidden_size
274
+ self.intermediate_size = config.intermediate_size
275
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
276
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
277
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
278
+ self.act_fn = ACT2FN[config.hidden_act]
279
+
280
+ def forward(self, x):
281
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
282
+ return down_proj
283
+
284
+
285
+ @use_kernel_forward_from_hub("RMSNorm")
286
+ class Ministral3RMSNorm(nn.Module):
287
+ def __init__(self, hidden_size, eps=1e-6):
288
+ """
289
+ Ministral3RMSNorm is equivalent to T5LayerNorm
290
+ """
291
+ super().__init__()
292
+ self.weight = nn.Parameter(torch.ones(hidden_size))
293
+ self.variance_epsilon = eps
294
+
295
+ def forward(self, hidden_states):
296
+ input_dtype = hidden_states.dtype
297
+ hidden_states = hidden_states.to(torch.float32)
298
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
299
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
300
+ return self.weight * hidden_states.to(input_dtype)
301
+
302
+ def extra_repr(self):
303
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
304
+
305
+
306
+ class Ministral3DecoderLayer(GradientCheckpointingLayer):
307
+ def __init__(self, config: NemotronLabsDiffusionVLMConfig, layer_idx: int):
308
+ super().__init__()
309
+ self.hidden_size = config.hidden_size
310
+
311
+ if hasattr(config, 'attn_class'):
312
+ attn_class = config.attn_class
313
+ else:
314
+ attn_class = Ministral3Attention
315
+
316
+ self.self_attn = attn_class(config=config, layer_idx=layer_idx)
317
+ self.mlp = Ministral3MLP(config)
318
+ self.input_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
319
+ self.post_attention_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
320
+
321
+ def forward(
322
+ self,
323
+ hidden_states: torch.Tensor,
324
+ attention_mask: Optional[torch.Tensor] = None,
325
+ position_ids: Optional[torch.LongTensor] = None,
326
+ past_key_values: Optional[Cache] = None,
327
+ use_cache: Optional[bool] = False,
328
+ cache_position: Optional[torch.LongTensor] = None,
329
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
330
+ **kwargs: Unpack[TransformersKwargs],
331
+ ) -> torch.Tensor:
332
+ residual = hidden_states
333
+ hidden_states = self.input_layernorm(hidden_states)
334
+ # Self Attention
335
+ hidden_states, _ = self.self_attn(
336
+ hidden_states=hidden_states,
337
+ attention_mask=attention_mask,
338
+ position_ids=position_ids,
339
+ past_key_values=past_key_values,
340
+ use_cache=use_cache,
341
+ cache_position=cache_position,
342
+ position_embeddings=position_embeddings,
343
+ **kwargs,
344
+ )
345
+ hidden_states = residual + hidden_states
346
+
347
+ # Fully Connected
348
+ residual = hidden_states
349
+ hidden_states = self.post_attention_layernorm(hidden_states)
350
+ hidden_states = self.mlp(hidden_states)
351
+ hidden_states = residual + hidden_states
352
+ return hidden_states
353
+
354
+
355
+ @auto_docstring
356
+ class Ministral3PreTrainedModel(PreTrainedModel):
357
+ config: NemotronLabsDiffusionVLMConfig
358
+ base_model_prefix = "model"
359
+ supports_gradient_checkpointing = True
360
+ # Ministral3RMSNorm must be a separate FSDP unit to avoid weight sharded to size 0 on some ranks
361
+ _no_split_modules = ["Ministral3DecoderLayer", "Ministral3RMSNorm"]
362
+ _skip_keys_device_placement = ["past_key_values"]
363
+ _supports_flash_attn = True
364
+ _supports_sdpa = True
365
+ _supports_flex_attn = True
366
+
367
+ _can_compile_fullgraph = True
368
+ _supports_attention_backend = True
369
+ _can_record_outputs = {
370
+ "hidden_states": Ministral3DecoderLayer,
371
+ "attentions": Ministral3Attention,
372
+ }
373
+
374
+
375
+ class Ministral3RotaryEmbedding(nn.Module):
376
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
377
+
378
+ def __init__(self, config: NemotronLabsDiffusionVLMConfig, device=None):
379
+ super().__init__()
380
+ self.max_seq_len_cached = config.max_position_embeddings
381
+ self.original_max_seq_len = config.max_position_embeddings
382
+
383
+ self.config = config
384
+
385
+ self.rope_type = self.config.rope_parameters["rope_type"]
386
+ rope_init_fn: Callable = self.compute_default_rope_parameters
387
+ if self.rope_type != "default":
388
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
389
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
390
+
391
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
392
+ self.original_inv_freq = inv_freq
393
+
394
+
395
+ @staticmethod
396
+ def compute_default_rope_parameters(
397
+ config: Optional[NemotronLabsDiffusionVLMConfig] = None,
398
+ device: Optional["torch.device"] = None,
399
+ seq_len: Optional[int] = None,
400
+ ) -> tuple["torch.Tensor", float]:
401
+ """
402
+ Computes the inverse frequencies according to the original RoPE implementation
403
+ Args:
404
+ config ([`~transformers.PreTrainedConfig`]):
405
+ The model configuration.
406
+ device (`torch.device`):
407
+ The device to use for initialization of the inverse frequencies.
408
+ seq_len (`int`, *optional*):
409
+ The current sequence length. Unused for this type of RoPE.
410
+ Returns:
411
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
412
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
413
+ """
414
+ base = config.rope_parameters["rope_theta"]
415
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
416
+
417
+ attention_factor = 1.0 # Unused in this type of RoPE
418
+
419
+ # Compute the inverse frequencies
420
+ inv_freq = 1.0 / (
421
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
422
+ )
423
+ return inv_freq, attention_factor
424
+
425
+ @torch.no_grad()
426
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
427
+ def forward(self, x, position_ids):
428
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
429
+ position_ids_expanded = position_ids[:, None, :].float()
430
+
431
+ # device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
432
+ # with maybe_autocast(device_type=device_type, enabled=False): # Force float32
433
+
434
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
435
+ emb = torch.cat((freqs, freqs), dim=-1)
436
+ cos = emb.cos() * self.attention_scaling
437
+ sin = emb.sin() * self.attention_scaling
438
+
439
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
440
+
441
+
442
+ @auto_docstring
443
+ class Ministral3Model(Ministral3PreTrainedModel):
444
+ def __init__(self, config: NemotronLabsDiffusionVLMConfig):
445
+ super().__init__(config)
446
+ vision_config = config.vision_config
447
+ if not isinstance(vision_config, PixtralVisionConfig):
448
+ vision_config = PixtralVisionConfig(**vision_config) if isinstance(vision_config, dict) else PixtralVisionConfig(**vars(vision_config))
449
+ config.vision_config = vision_config
450
+
451
+ self.vision_tower = PixtralVisionModel(vision_config)
452
+ self.multi_modal_projector = Ministral3MultiModalProjector(config)
453
+ self.padding_idx = config.pad_token_id
454
+ self.vocab_size = config.vocab_size
455
+
456
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
457
+ self.layers = nn.ModuleList(
458
+ [Ministral3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
459
+ )
460
+ self.norm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
461
+ self.rotary_emb = Ministral3RotaryEmbedding(config=config)
462
+ self.gradient_checkpointing = False
463
+
464
+ # Initialize weights and apply final processing
465
+ self.post_init()
466
+
467
+ @check_model_inputs
468
+ @auto_docstring
469
+ def forward(
470
+ self,
471
+ input_ids: Optional[torch.LongTensor] = None,
472
+ attention_mask: Optional[torch.Tensor] = None,
473
+ position_ids: Optional[torch.LongTensor] = None,
474
+ past_key_values: Optional[Cache] = None,
475
+ inputs_embeds: Optional[torch.FloatTensor] = None,
476
+ use_cache: Optional[bool] = None,
477
+ cache_position: Optional[torch.LongTensor] = None,
478
+ **kwargs: Unpack[TransformersKwargs],
479
+ ) -> BaseModelOutputWithPast:
480
+ if (input_ids is None) ^ (inputs_embeds is not None):
481
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
482
+
483
+ if inputs_embeds is None:
484
+ inputs_embeds = self.embed_tokens(input_ids)
485
+
486
+ if use_cache and past_key_values is None:
487
+ # past_key_values = DynamicCache(config=self.config)
488
+ past_key_values = DynamicCache()
489
+
490
+ if cache_position is None:
491
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
492
+ cache_position = torch.arange(
493
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
494
+ )
495
+
496
+ if position_ids is None:
497
+ position_ids = cache_position.unsqueeze(0)
498
+
499
+ if kwargs.get("use_causal_mask", False):
500
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
501
+ causal_mask = mask_function(
502
+ config=self.config,
503
+ input_embeds=inputs_embeds,
504
+ attention_mask=attention_mask,
505
+ cache_position=cache_position,
506
+ past_key_values=past_key_values,
507
+ position_ids=position_ids,
508
+ )
509
+
510
+ else:
511
+ causal_mask = None
512
+
513
+ hidden_states = inputs_embeds
514
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
515
+
516
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
517
+ hidden_states = decoder_layer(
518
+ hidden_states,
519
+ attention_mask=causal_mask,
520
+ position_ids=position_ids,
521
+ past_key_values=past_key_values,
522
+ use_cache=use_cache,
523
+ cache_position=cache_position,
524
+ position_embeddings=position_embeddings,
525
+ **kwargs,
526
+ )
527
+ hidden_states = self.norm(hidden_states)
528
+ return BaseModelOutputWithPast(
529
+ last_hidden_state=hidden_states,
530
+ past_key_values=past_key_values if use_cache else None,
531
+ )
532
+
533
+
534
+ @auto_docstring
535
+ class Ministral3ForCausalLM(Ministral3PreTrainedModel, GenerationMixin):
536
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
537
+ _tp_plan = {"lm_head": "colwise_rep"}
538
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
539
+
540
+ def __init__(self, config):
541
+ super().__init__(config)
542
+ self.model = Ministral3Model(config)
543
+ self.vocab_size = config.vocab_size
544
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
545
+
546
+ # Initialize weights and apply final processing
547
+ self.post_init()
548
+
549
+ @can_return_tuple
550
+ @auto_docstring
551
+ def forward(
552
+ self,
553
+ input_ids: Optional[torch.LongTensor] = None,
554
+ attention_mask: Optional[torch.Tensor] = None,
555
+ position_ids: Optional[torch.LongTensor] = None,
556
+ past_key_values: Optional[Cache] = None,
557
+ inputs_embeds: Optional[torch.FloatTensor] = None,
558
+ labels: Optional[torch.LongTensor] = None,
559
+ use_cache: Optional[bool] = None,
560
+ cache_position: Optional[torch.LongTensor] = None,
561
+ logits_to_keep: Union[int, torch.Tensor] = 0,
562
+ **kwargs: Unpack[TransformersKwargs],
563
+ ) -> CausalLMOutputWithPast:
564
+ r"""
565
+ Example:
566
+
567
+ ```python
568
+ >>> from transformers import AutoTokenizer, Ministral3ForCausalLM
569
+
570
+ >>> model = Ministral3ForCausalLM.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
571
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
572
+
573
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
574
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
575
+
576
+ >>> # Generate
577
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
578
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
579
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
580
+ ```"""
581
+ outputs: BaseModelOutputWithPast = self.model(
582
+ input_ids=input_ids,
583
+ attention_mask=attention_mask,
584
+ position_ids=position_ids,
585
+ past_key_values=past_key_values,
586
+ inputs_embeds=inputs_embeds,
587
+ use_cache=use_cache,
588
+ cache_position=cache_position,
589
+ **kwargs,
590
+ )
591
+
592
+ hidden_states = outputs.last_hidden_state
593
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
594
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
595
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
596
+
597
+ loss = None
598
+ if labels is not None:
599
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
600
+
601
+ return CausalLMOutputWithPast(
602
+ loss=loss,
603
+ logits=logits,
604
+ past_key_values=outputs.past_key_values,
605
+ hidden_states=outputs.hidden_states,
606
+ attentions=outputs.attentions,
607
+ )
608
+
609
+
610
+ class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
611
+ pass
612
+
613
+
614
+ class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
615
+ pass
616
+
617
+
618
+ class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
619
+ pass
620
+
621
+
622
+ __all__ = [
623
+ "Ministral3ForCausalLM",
624
+ "Ministral3ForQuestionAnswering",
625
+ "Ministral3Model",
626
+ "Ministral3PreTrainedModel",
627
+ "Ministral3ForSequenceClassification",
628
+ "Ministral3ForTokenClassification",
629
+ ]
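
Note on `Ministral3RotaryEmbedding` above: the default path computes `inv_freq[i] = 1 / rope_theta ** (2i / head_dim)`, multiplies by the position, duplicates the angles with `torch.cat((freqs, freqs), dim=-1)`, and scales the resulting cos/sin by `attention_scaling`. The following is a minimal sketch of that arithmetic in plain Python for a single scalar position. It is not code from this commit: the helper names `rope_inv_freq` and `rope_cos_sin` are hypothetical, and lists stand in for tensors.

```python
import math


def rope_inv_freq(head_dim, base):
    """Inverse frequencies 1 / base**(i / head_dim) for i = 0, 2, 4, ...

    Hypothetical helper mirroring the arithmetic of
    compute_default_rope_parameters, without torch.
    """
    return [1.0 / (base ** (i / head_dim)) for i in range(0, head_dim, 2)]


def rope_cos_sin(position, inv_freq, attention_scaling=1.0):
    """cos/sin vectors for one position (hypothetical helper).

    The angle list is concatenated with itself, as torch.cat((freqs, freqs))
    does in the module's forward, so the result has head_dim entries.
    """
    freqs = [position * f for f in inv_freq]
    emb = freqs + freqs  # list concatenation, length == head_dim
    cos = [math.cos(a) * attention_scaling for a in emb]
    sin = [math.sin(a) * attention_scaling for a in emb]
    return cos, sin


if __name__ == "__main__":
    inv_freq = rope_inv_freq(head_dim=8, base=10000.0)
    cos, sin = rope_cos_sin(position=3, inv_freq=inv_freq)
    print(len(inv_freq), len(cos), len(sin))
```

With `attention_scaling = 1.0`, as returned by the default initializer, the scaling step is a no-op. The other `rope_type` values dispatch to `ROPE_INIT_FUNCTIONS` and may return a different factor.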
modeling_nemotron_labs_diffusion_vlm.py ADDED
@@ -0,0 +1,1378 @@
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Optional, Tuple, Union
4
+ import random
5
+ import os
6
+ import sys
7
+ import json
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
+ from transformers.utils import ModelOutput
15
+
16
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, or_masks
17
+
18
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
+
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from transformers.cache_utils import Cache, DynamicCache
23
+
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.generation import GenerationMixin
27
+ from transformers.loss.loss_utils import LOSS_MAPPING
28
+
29
+ import math
30
+
31
+ from .chat_utils import generate_with_prefix_cache_block_diff
32
+ from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
33
+ from .configuration_nemotron_labs_diffusion_vlm import NemotronLabsDiffusionVLMConfig
34
+
35
+
36
+ @dataclass
37
+ class NemotronLabsDiffusionVLMOutputWithPast(ModelOutput):
38
+ loss: torch.FloatTensor | None = None
39
+ logits: torch.FloatTensor | None = None
40
+ causal_logits: torch.FloatTensor | None = None
41
+ past_key_values: Cache | None = None
42
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
43
+ attentions: tuple[torch.FloatTensor, ...] | None = None
44
+
45
+
46
+ # @torch.compile(dynamic=True, mode="reduce-overhead")
47
+ # @torch.compile(mode="default")
48
+ # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
49
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
50
+ def fused_flex_attention(q, k, v, block_mask=None):
51
+ return flex_attention(q, k, v, block_mask=block_mask)
52
+
53
+
54
+ def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
55
+ """Crop a DynamicCache to max_length, compatible with both old and new transformers."""
56
+ if hasattr(past_key_values, 'crop'):
57
+ past_key_values.crop(max_length)
58
+ else:
59
+ for layer_idx in range(len(past_key_values)):
60
+ past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
61
+ past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
62
+ past_key_values._seen_tokens = max_length
63
+
64
+
65
+ def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block_length: int):
66
+ """After quadratic decoding, extract only draft tokens (first of each block) from cache."""
67
+ for layer_idx in range(len(past_key_values)):
68
+ if hasattr(past_key_values, 'layers'):
69
+ layer_cache = past_key_values.layers[layer_idx]
70
+ k, v = layer_cache.keys, layer_cache.values
71
+ else:
72
+ k = past_key_values.key_cache[layer_idx]
73
+ v = past_key_values.value_cache[layer_idx]
74
+
75
+ clean_k, draft_k = k[:, :, :clean_len], k[:, :, clean_len::block_length + 1]
76
+ clean_v, draft_v = v[:, :, :clean_len], v[:, :, clean_len::block_length + 1]
77
+ new_k = torch.cat([clean_k, draft_k], dim=2)
78
+ new_v = torch.cat([clean_v, draft_v], dim=2)
79
+
80
+ if hasattr(past_key_values, 'layers'):
81
+ layer_cache.keys = new_k
82
+ layer_cache.values = new_v
83
+ else:
84
+ past_key_values.key_cache[layer_idx] = new_k
85
+ past_key_values.value_cache[layer_idx] = new_v
86
+
87
+ past_key_values._seen_tokens = clean_len + block_length
88
+
89
+ # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
90
+ class NemotronLabsDiffusionVLMFlexAttention(Ministral3Attention):
91
+ def __init__(self, *args, **kwargs):
92
+ super().__init__(*args, **kwargs)
93
+
94
+ self.max_seq_length = getattr(self.config, 'max_seq_length', 4096)
95
+ self.block_size_orig = self.config.block_size
96
+
97
+ if self.config.dlm_paradigm == 'bidirectional':
98
+ self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
99
+ elif self.config.dlm_paradigm == 'autoregressive':
100
+ self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
101
+ elif self.config.dlm_paradigm == 'block_diff':
102
+ self.block_diff_mask = None
103
+ elif self.config.dlm_paradigm == 'sbd_block_diff':
104
+ self.sbd_block_diff_mask = None
105
+ else:
106
+ raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
107
+
108
+ self.block_size = self.block_size_orig
109
+ self.mode = self.config.dlm_paradigm
110
+ self._quadratic_block_mask = {}
111
+
112
+ import torch._dynamo.config as dcfg
113
+ dcfg.cache_size_limit = 512
114
+
115
+ def _get_sbd_inference_quadratic_decoding_block_mask(self, block_length: int):
116
+ if block_length not in self._quadratic_block_mask:
117
+ draft_len = block_length * (block_length + 1)
118
+
119
+ def quadratic(b, h, q_idx, kv_idx):
120
+ first_clean = torch.logical_and(
121
+ kv_idx % (block_length + 1) == 0,
122
+ kv_idx < draft_len,
123
+ )
124
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
125
+ block_q = q_idx // (block_length + 1)
126
+ block_kv = kv_idx // (block_length + 1)
127
+ same_block = torch.logical_and(block_q == block_kv, q_idx < draft_len)
128
+ same_block_except_first = torch.logical_and(
129
+ same_block,
130
+ q_idx % (block_length + 1) != 0,
131
+ )
132
+ draft_part = torch.logical_or(first_clean, same_block_except_first)
133
+ clean_part = kv_idx >= draft_len
134
+ return torch.logical_or(draft_part, clean_part)
135
+
136
+ block_mask = create_block_mask(
137
+ quadratic,
138
+ B=None,
139
+ H=None,
140
+ Q_LEN=draft_len,
141
+ KV_LEN=draft_len + self.config.max_position_embeddings,
142
+ device="cuda",
143
+ )
144
+
145
+ self._quadratic_block_mask[block_length] = block_mask
146
+
147
+ return self._quadratic_block_mask[block_length]
148
+
149
+ def set_attention_mode(self, mode, block_size=None):
150
+ self.mode = mode
151
+ self.block_size = block_size
152
+
153
+ def compute_block_mask(self, mode, q_len=None, block_size=None):
154
+
155
+ def bidirectional_mask(b, h, q, kv):
156
+ return (q >= kv) | (q < kv)
157
+
158
+ def autoregressive_mask(b, h, q, kv):
159
+ return (q >= kv)
160
+
161
+ def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
162
+ """
163
+ Constructs the specialized block diffusion attention mask for training
164
+ composed of three masks:
165
+ - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
166
+ - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
167
+ - **Block Causal Mask (M_BC)**: Attention to update x0
168
+ Args:
169
+ b, h: Batch and head indices (ignored for mask logic).
170
+ q_idx, kv_idx: Query and Key indices.
171
+ n: Length of the noised half (xt); indices >= n belong to the clean half (x0).
172
+ block_size: Defines the block structure.
173
+ Returns:
174
+ A boolean attention mask.
175
+ """
176
+
177
+ # Indicate whether token belongs to xt or x0
178
+ x0_flag_q = (q_idx >= n)
179
+ x0_flag_kv = (kv_idx >= n)
180
+
181
+ # Compute block indices
182
+ block_q = torch.where(x0_flag_q == 1,
183
+ (q_idx - n) // block_size,
184
+ q_idx // block_size)
185
+ block_kv = torch.where(x0_flag_kv == 1,
186
+ (kv_idx - n) // block_size,
187
+ kv_idx // block_size)
188
+
189
+ # **1. Block Diagonal Mask (M_BD) **
190
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
191
+
192
+ # **2. Offset Block-Causal Mask (M_OBC) **
193
+ offset_block_causal = (
194
+ (block_q > block_kv)
195
+ & (x0_flag_kv == 1)
196
+ & (x0_flag_q == 0)
197
+ )
198
+
199
+ # **3. Block-Causal Mask (M_BC) **
200
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
201
+
202
+ # **4. Combine Masks **
203
+ return block_diagonal | offset_block_causal | block_causal
204
+
205
+
206
+ def sbd_block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
207
+ x0_flag_q = (q_idx >= n)
208
+ x0_flag_kv = (kv_idx >= n)
209
+
210
+ # Compute block indices
211
+ block_q = torch.where(x0_flag_q == 1,
212
+ (q_idx - n) // block_size,
213
+ q_idx // block_size)
214
+ block_kv = torch.where(x0_flag_kv == 1,
215
+ (kv_idx - n) // block_size,
216
+ kv_idx // block_size)
217
+
218
+ # **1. Block Diagonal Mask (M_BD) **
219
+ block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
220
+
221
+ # **2. Offset Block-Causal Mask (M_OBC) **
222
+ offset_block_causal = (
223
+ (block_q > block_kv)
224
+ & (x0_flag_kv == 1)
225
+ & (x0_flag_q == 0)
226
+ )
227
+
228
+ # **3. Fully Causal Mask (token-level causal attention within x0) **
229
+ fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
230
+
231
+ # **4. Combine Masks **
232
+ return block_diagonal | offset_block_causal | fully_causal
233
+
234
+ if mode == 'bidirectional':
235
+ attn_mask = bidirectional_mask
236
+ elif mode == 'autoregressive':
237
+ attn_mask = autoregressive_mask
238
+ elif mode == 'block_diff':
239
+ assert block_size is not None
240
+ n = (q_len // 2) if q_len is not None else self.max_seq_length
241
+ attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, n)
242
+ elif mode == 'sbd_block_diff':
243
+ assert block_size is not None
244
+ n = (q_len // 2) if q_len is not None else self.max_seq_length
245
+ attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, n)
246
+ else:
247
+ raise ValueError(f"Unknown attention mode: {mode}")
248
+
249
+ if q_len is not None:
250
+ Q_LEN = q_len
251
+ else:
252
+ if mode in ['block_diff', 'sbd_block_diff']:
253
+ Q_LEN = self.max_seq_length * 2
254
+ else:
255
+ Q_LEN = self.max_seq_length
256
+
257
+ block_mask = create_block_mask(
258
+ attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
259
+ )
260
+
261
+ return block_mask
262
+
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
268
+ attention_mask: Optional[torch.Tensor],
269
+ past_key_values: Optional[Cache] = None,
270
+ cache_position: Optional[torch.LongTensor] = None,
271
+ is_training: bool = True,
272
+ **kwargs: Unpack[FlashAttentionKwargs],
273
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
274
+ bsz, q_len, _ = hidden_states.size()
275
+ input_shape = hidden_states.shape[:-1]
276
+ hidden_shape = (*input_shape, -1, self.head_dim)
277
+
278
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
279
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
280
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
281
+
282
+ cos, sin = position_embeddings
283
+
284
+ if self.mode in ['block_diff', 'sbd_block_diff'] and is_training:
285
+ # Split query and key states in half along sequence length dimension
286
+ q1, q2 = query_states.chunk(2, dim=2)
287
+ k1, k2 = key_states.chunk(2, dim=2)
288
+
289
+ # Apply RoPE independently to each half
290
+ q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
291
+ q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
292
+
293
+ # Recombine the halves
294
+ query_states = torch.cat([q1, q2], dim=2)
295
+ key_states = torch.cat([k1, k2], dim=2)
296
+ else:
297
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
298
+
299
+ query_states = query_states * _get_llama_4_attn_scale(
300
+ cache_position,
301
+ self.config.rope_parameters.get("llama_4_scaling_beta"),
302
+ self.config.rope_parameters.get("original_max_position_embeddings"),
303
+ ).to(query_states.dtype)
304
+
305
+ if past_key_values is not None:
306
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
307
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
308
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
309
+
310
+ self_spec_inference_mode = getattr(self.config, "self_spec_inference_mode", None)
311
+ if self_spec_inference_mode is not None:
312
+ if self_spec_inference_mode == "quadratic":
313
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
314
+ if block_length is None:
315
+ raise ValueError("SBD quadratic decoding requires block_length in config.")
316
+ if past_key_values is not None:
317
+ seq_len = key_states.shape[2]
318
+ draft_len = block_length * (block_length + 1)
319
+
320
+ clean_keys = key_states[:, :, :-draft_len]
321
+ draft_keys = key_states[:, :, -draft_len:]
322
+ clean_values = value_states[:, :, :-draft_len]
323
+ draft_values = value_states[:, :, -draft_len:]
324
+ key_states = torch.cat([draft_keys, clean_keys], dim=2)
325
+ value_states = torch.cat([draft_values, clean_values], dim=2)
326
+
327
+ block_mask = self._get_sbd_inference_quadratic_decoding_block_mask(
328
+ block_length=block_length
329
+ )
330
+ block_mask.seq_lengths = (draft_len, seq_len)
331
+ else:
332
+ seq_len = query_states.shape[2]
333
+ draft_len = block_length * (block_length + 1)
334
+ clean_len = seq_len - draft_len
335
+
336
+ def _causal_mask(b, h, q_idx, kv_idx):
337
+ return torch.logical_and(q_idx >= kv_idx, q_idx < clean_len)
338
+
339
+ def _draft2clean_mask(b, h, q_idx, kv_idx):
340
+ full_clean = torch.logical_and(q_idx >= clean_len, kv_idx <= clean_len)
341
+ first_clean = torch.logical_and(
342
+ q_idx >= clean_len, (kv_idx - clean_len) % (block_length + 1) == 0
343
+ )
344
+ first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
345
+ return torch.logical_or(full_clean, first_clean)
346
+
347
+ def _draft_mask(b, h, q_idx, kv_idx):
348
+ block_q = (q_idx - clean_len) // (block_length + 1)
349
+ block_kv = (kv_idx - clean_len) // (block_length + 1)
350
+ quadrant = torch.logical_and(q_idx >= clean_len, kv_idx >= clean_len)
351
+ same_block = torch.logical_and(block_q == block_kv, quadrant)
352
+ same_block_except_first = torch.logical_and(
353
+ same_block,
354
+ (q_idx - clean_len) % (block_length + 1) != 0,
355
+ )
356
+ return same_block_except_first # already implies block_q == block_kv
357
+
358
+ mask = or_masks(_causal_mask, _draft2clean_mask)
359
+ mask = or_masks(mask, _draft_mask)
360
+
361
+ block_mask = create_block_mask(
362
+ mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
363
+ )
364
+
365
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
366
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
367
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
368
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
369
+ attn_output = self.o_proj(attn_output)
370
+ return attn_output, None
371
+
372
+ elif self_spec_inference_mode == "default":
373
+ block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
374
+ if block_length is None:
375
+ raise ValueError("SBD default decoding requires block_length in config.")
376
+ seq_len = query_states.shape[2]
377
+ prefix_len = seq_len - block_length
378
+
379
+ def _clean_q_mask(b, h, q_idx, kv_idx):
380
+ return torch.logical_and(q_idx >= kv_idx, q_idx < prefix_len)
381
+
382
+ def _noisy_q_mask(b, h, q_idx, kv_idx):
383
+ return q_idx >= prefix_len
384
+
385
+ block_mask = create_block_mask(
386
+ or_masks(_clean_q_mask, _noisy_q_mask),
387
+ B=None,
388
+ H=None,
389
+ Q_LEN=seq_len,
390
+ KV_LEN=seq_len,
391
+ )
392
+
393
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
394
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
395
+ attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
396
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
397
+ attn_output = self.o_proj(attn_output)
398
+ return attn_output, None
399
+
400
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
401
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
402
+
403
+ if self.mode == 'bidirectional':
404
+ if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
405
+ block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
406
+ else:
407
+ block_mask = self.bidirectional_mask
408
+
409
+ elif self.mode == 'autoregressive':
410
+ if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
411
+ block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
412
+ else:
413
+ block_mask = self.autoregressive_mask
414
+
415
+ elif self.mode == 'block_diff':
416
+ if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
417
+ block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
418
+ else:
419
+ block_mask = self.block_diff_mask
420
+ elif self.mode == 'sbd_block_diff':
421
+ if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
422
+ block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
423
+ else:
424
+ block_mask = self.sbd_block_diff_mask
425
+ else:
426
+ raise ValueError(f"Unknown attention mode: {self.mode}")
427
+
428
+ attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
429
+ attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
430
+
431
+ attn_output = self.o_proj(attn_output)
432
+
433
+ return attn_output, None
434
+
435
+
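The quadratic-decoding branch above composes three mask functions with `or_masks`. As a reading aid, the sketch below rebuilds the same boolean rule as a dense tensor in plain PyTorch so it can be printed for small sizes. The helper name `quadratic_decoding_mask` is hypothetical, and the dense form is only meant for inspection; the model itself builds a `BlockMask` through `create_block_mask`.

```python
import torch


def quadratic_decoding_mask(clean_len: int, block_length: int) -> torch.Tensor:
    """Dense (seq_len, seq_len) bool mask; True means the query row may attend to the key column."""
    stride = block_length + 1                      # one draft token + block_length mask tokens
    seq_len = clean_len + block_length * stride
    q = torch.arange(seq_len)[:, None]
    kv = torch.arange(seq_len)[None, :]

    q_draft = q >= clean_len
    # Clean prefix: ordinary causal attention.
    causal = (q >= kv) & (q < clean_len)
    # Draft queries see the clean prefix (and the key at index clean_len).
    full_clean = q_draft & (kv <= clean_len)
    # Draft queries also see the first token of every earlier-or-equal draft block.
    first_tok = q_draft & ((kv - clean_len) % stride == 0) & (q >= kv)
    # Non-first tokens of a draft block see their whole block.
    block_q = torch.div(q - clean_len, stride, rounding_mode="floor")
    block_kv = torch.div(kv - clean_len, stride, rounding_mode="floor")
    intra = q_draft & (kv >= clean_len) & (block_q == block_kv) & ((q - clean_len) % stride != 0)

    return causal | full_clean | first_tok | intra


mask = quadratic_decoding_mask(clean_len=3, block_length=2)
print(mask.int())
```

With `clean_len=3` and `block_length=2`, each draft block occupies three positions, so the first draft block is rows 3 to 5 and the second is rows 6 to 8.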
436
+ def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
437
+ """Return a bool mask shaped like log_w with exactly k True entries, sampled without replacement via the Gumbel top-k trick."""
438
+ g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
439
+ topk = torch.topk(log_w + g, k).indices
440
+ mask = torch.zeros_like(log_w, dtype=torch.bool)
441
+ mask[topk] = True
442
+ return mask
443
+
444
+
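`gumbel_topk` adds independent Gumbel noise to the log-weights and keeps the k largest perturbed values, which draws k distinct indices with probability tied to `exp(log_w)`. A minimal standalone sketch follows; the `generator` argument and the name `gumbel_topk_mask` are additions for reproducibility in this example and are not part of the function above.

```python
import torch


def gumbel_topk_mask(log_w: torch.Tensor, k: int, generator=None) -> torch.Tensor:
    """Pick k distinct positions, favouring larger log-weights, and return them as a bool mask."""
    u = torch.rand(log_w.shape, generator=generator)
    # Standard Gumbel noise; the small constants guard against log(0).
    g = -torch.log(-torch.log(u + 1e-9) + 1e-9)
    idx = torch.topk(log_w + g, k).indices
    mask = torch.zeros_like(log_w, dtype=torch.bool)
    mask[idx] = True
    return mask


gen = torch.Generator().manual_seed(0)
log_w = 0.3 * torch.arange(16, dtype=torch.float32)  # later positions get larger weights
mask = gumbel_topk_mask(log_w, k=5, generator=gen)
print(mask)
```

Because `topk` returns distinct indices, the mask always has exactly k True entries regardless of the noise drawn.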
445
+ class NemotronLabsDiffusionVLMModel(Ministral3PreTrainedModel, GenerationMixin):
446
+ """
447
+ A single model with:
448
+ - a bidirectional encoder + diffusion‐LM head over A
449
+ - a causal decoder + LM head over B, conditioned on F_A
450
+ """
451
+
452
+ def __init__(self, config: NemotronLabsDiffusionVLMConfig):
453
+ super().__init__(config)
454
+
455
+ self.mask_token_id = config.mask_token_id
456
+
457
+ diffusion_config = copy.deepcopy(config)
458
+ diffusion_config.diffusion_lm = True
459
+
460
+ use_flex = getattr(config, 'enable_self_spec', False)
461
+
462
+ if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
463
+ diffusion_config.attn_class = NemotronLabsDiffusionVLMFlexAttention
464
+ elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
465
+ diffusion_config.attn_class = NemotronLabsDiffusionVLMFlexAttention if use_flex else Ministral3Attention
466
+ if config.dlm_paradigm == 'autoregressive':
467
+ diffusion_config.diffusion_lm = False
468
+ else:
469
+ raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
470
+
471
+ self.encoder = Ministral3Model(diffusion_config)
472
+ self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
473
+ self.vocab_size = config.vocab_size
474
+
475
+ self.current_iter_ratio = None
476
+ self.mdm_loss_function = LOSS_MAPPING['ForMaskedLM']
477
+ self.causal_loss_function = LOSS_MAPPING['ForCausalLM']
478
+
479
+ self.post_init()
480
+
481
+
482
+ def get_input_embeddings(self):
483
+ return self.encoder.embed_tokens
484
+
485
+ def set_input_embeddings(self, value):
486
+ self.encoder.embed_tokens = value
487
+
488
+ def get_output_embeddings(self):
489
+ return self.diffusion_head
490
+
491
+ def set_output_embeddings(self, new_embeddings):
492
+ self.diffusion_head = new_embeddings
493
+
494
+ def forward_process_complementary(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
495
+ device = input_ids.device
496
+
497
+ if self.config.dp_varying_mask_ratio:
498
+ import torch.distributed as dist
499
+ dp_rank = 0
500
+ if dist.is_initialized():
501
+ try:
502
+ dp_rank = dist.get_rank()
503
+ except Exception:
504
+ dp_rank = 0
505
+ generator = torch.Generator(device=device)
506
+ generator.manual_seed(torch.seed() + dp_rank)
507
+ else:
508
+ generator = None
509
+
510
+ noisy_input_ids = input_ids.clone()
511
+ input_ids_flat = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // block_size, block_size)
512
+ b, l = input_ids_flat.shape
513
+ t = torch.rand((b,), device=input_ids.device, generator=generator)
514
+ p_mask = (1 - eps) * t + eps
515
+ p_mask = p_mask[:, None].repeat(1, l)
516
+
517
+ masked_indices = (torch.rand((b, l), device=input_ids.device, generator=generator) < p_mask).reshape(noisy_input_ids.shape)
518
+ input_ids_flat = input_ids_flat.reshape(noisy_input_ids.shape)
519
+
520
+ complementary_noisy_input_ids = input_ids.clone()
521
+ complementary_masked_indices = ~masked_indices
522
+
523
+ if getattr(self.config, 'always_mask_im_end', False):
524
+ im_end_mask = (input_ids == self.config.im_end_token_id)
525
+ masked_indices = masked_indices | im_end_mask
526
+ complementary_masked_indices = complementary_masked_indices | im_end_mask
527
+
528
+ if loss_mask is not None:
529
+ masked_indices[loss_mask == 0] = 0
530
+ complementary_masked_indices[loss_mask == 0] = 0
531
+
532
+ noisy_input_ids[masked_indices] = self.mask_token_id
533
+ complementary_noisy_input_ids[complementary_masked_indices] = self.mask_token_id
534
+
535
+ noisy_input_ids = torch.cat([noisy_input_ids, complementary_noisy_input_ids], dim=0)
536
+ masked_indices = torch.cat([masked_indices, complementary_masked_indices], dim=0)
537
+ return noisy_input_ids, masked_indices, None
538
+
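`forward_process_complementary` samples one mask per block and pairs it with its complement, so every token is masked in exactly one of the two noisy copies (before the optional `im_end` and `loss_mask` adjustments). The sketch below reproduces only that core idea on CPU tensors; `complementary_masks` is a hypothetical helper, and it assumes the sequence length is divisible by `block_size`.

```python
import torch


def complementary_masks(input_ids, mask_token_id, block_size, eps=1e-3, generator=None):
    """Return (noisy_ids, masked) stacked along the batch dim: original mask first, complement second."""
    batch, length = input_ids.shape
    assert length % block_size == 0, "sketch assumes whole blocks"
    n_blocks = batch * length // block_size

    # One masking ratio per block, shared by all tokens in the block.
    t = torch.rand((n_blocks,), generator=generator)
    p_mask = ((1 - eps) * t + eps)[:, None].expand(-1, block_size)
    masked = (torch.rand((n_blocks, block_size), generator=generator) < p_mask).reshape(batch, length)
    complement = ~masked

    fill = torch.full_like(input_ids, mask_token_id)
    noisy = torch.where(masked, fill, input_ids)
    noisy_complement = torch.where(complement, fill, input_ids)
    return torch.cat([noisy, noisy_complement], dim=0), torch.cat([masked, complement], dim=0)


ids = torch.arange(1, 17).reshape(2, 8)  # token id 0 is reserved as the mask id here
noisy_ids, masked = complementary_masks(ids, mask_token_id=0, block_size=4,
                                        generator=torch.Generator().manual_seed(0))
print(noisy_ids)
```

The doubled batch dimension is why later code concatenates `labels` with itself before indexing by `masked_indices`.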
539
+ # ── Vision / multimodal helpers (ported from Mistral3Model) ──────────
540
+
541
+ IMAGE_TOKEN_ID = 19
542
+
543
+ def get_image_features(
544
+ self,
545
+ pixel_values: torch.FloatTensor,
546
+ image_sizes: torch.Tensor,
547
+ ) -> torch.FloatTensor:
548
+ """
549
+ Run the vision tower + multimodal projector and return a flat tensor
550
+ of image features ready to be scattered into the text embeddings.
551
+
552
+ Mirrors ``Mistral3Model.get_image_features`` from
553
+ transformers/models/mistral3/modeling_mistral3.py.
554
+
555
+ Returns:
556
+ Flat (total_image_tokens, hidden_size) tensor.
557
+ """
558
+ vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
559
+
560
+ image_outputs = self.encoder.vision_tower(
561
+ pixel_values,
562
+ image_sizes=image_sizes,
563
+ output_hidden_states=True,
564
+ return_dict=True,
565
+ )
566
+
567
+ if isinstance(vision_feature_layer, int):
568
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
569
+ else:
570
+ hs_pool = [image_outputs.hidden_states[idx] for idx in vision_feature_layer]
571
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
572
+
573
+ image_features = self.encoder.multi_modal_projector(
574
+ selected_image_feature.squeeze(0), image_sizes,
575
+ )
576
+
577
+ # Split per image, then re-cat into one flat tensor
578
+ downsample_ratio = (
579
+ self.encoder.vision_tower.patch_size
580
+ * getattr(self.config, "spatial_merge_size", 2)
581
+ )
582
+ split_sizes = (
583
+ (torch.as_tensor(image_sizes, device=image_features.device) // downsample_ratio)
584
+ .prod(dim=-1)
585
+ .tolist()
586
+ )
587
+ # per_image = torch.split(image_features.squeeze(0), split_sizes)
588
+ per_image = torch.split(image_features, split_sizes)
589
+
590
+ return torch.cat(per_image, dim=0) # (total_tokens, hidden)
591
+
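The split sizes in `get_image_features` come from dividing each image's height and width by `patch_size * spatial_merge_size` and multiplying the two results. A small worked sketch, assuming a patch size of 14 purely for illustration (the real value comes from the vision tower config) and a hypothetical helper name `image_token_counts`:

```python
import torch


def image_token_counts(image_sizes, patch_size, spatial_merge_size=2):
    """Number of projected feature rows each image contributes."""
    downsample_ratio = patch_size * spatial_merge_size
    return (torch.as_tensor(image_sizes) // downsample_ratio).prod(dim=-1).tolist()


# Two images of (height, width) = (448, 448) and (224, 336); patch_size=14 is an assumed value.
counts = image_token_counts([(448, 448), (224, 336)], patch_size=14)

# Stand-in for the projector output: one row per image token, hidden size 8.
features = torch.randn(sum(counts), 8)
per_image = torch.split(features, counts)
print(counts, [tuple(p.shape) for p in per_image])
```

The sum of these counts is what the assertion in `_embed_with_vision` compares against the number of image placeholder tokens in `input_ids`.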
592
+ def _is_vision_frozen(self) -> bool:
593
+ """True if vision_tower and multi_modal_projector have no parameters requiring grad (e.g. --freeze_vision_encoder)."""
594
+ vt = self.encoder.vision_tower
595
+ proj = self.encoder.multi_modal_projector
596
+ vt_has_grad = any(p.requires_grad for p in vt.parameters())
597
+ proj_has_grad = any(p.requires_grad for p in proj.parameters())
598
+ return not vt_has_grad and not proj_has_grad
599
+
600
+ def _embed_with_vision(
601
+ self,
602
+ input_ids: torch.LongTensor,
603
+ pixel_values: torch.FloatTensor,
604
+ image_sizes: torch.Tensor,
605
+ ) -> torch.FloatTensor:
606
+ """
607
+ Embed *input_ids* and scatter vision features into [IMG] pad positions.
608
+
609
+ Returns:
610
+ inputs_embeds (batch, seq_len, hidden_size)
611
+ """
612
+ inputs_embeds = self.encoder.embed_tokens(input_ids)
613
+ image_features = self.get_image_features(pixel_values, image_sizes)
614
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
615
+
616
+ # Boolean mask: positions that are IMG pad tokens
617
+ special_image_mask = (input_ids == self.IMAGE_TOKEN_ID)
618
+
619
+ if self.training:
620
+ if self.config.complementary_mask:
621
+ image_features = image_features.repeat(2, 1)
622
+ if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
623
+ image_features = image_features.repeat(2, 1)
624
+
625
+ assert special_image_mask.sum() == image_features.shape[0], f"special_image_mask.sum() = {special_image_mask.sum()}, image_features.shape[0] = {image_features.shape[0]}"
626
+ # Expand to hidden dim for masked_scatter
627
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)
628
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
629
+ return inputs_embeds
630
+
631
+ def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
632
+ b, l = input_ids.shape
633
+ device = input_ids.device
634
+
635
+ if self.config.dp_varying_mask_ratio:
636
+ # Enable different random seeds for each DP rank during sampling
637
+ import torch.distributed as dist
638
+ dp_rank = 0
639
+ if dist.is_initialized():
640
+ try:
641
+ dp_rank = dist.get_rank()
642
+ except Exception:
643
+ dp_rank = 0
644
+ # Use a local generator to avoid affecting global RNG state
645
+ generator = torch.Generator(device=device)
646
+ generator.manual_seed(torch.seed() + dp_rank)
647
+ else:
648
+ generator = None
649
+
650
+ if self.config.adaptive_mask_rate:
651
+ assert block_size is not None
652
+
653
+ # --- simple linear window mapping ---
654
+ bs_min = getattr(self.config, "t_bs_min", 16)
655
+ bs_max = getattr(self.config, "t_bs_max", 128)
656
+ w = getattr(self.config, "t_window_width", 0.6) # fixed width
657
+
658
+ # fraction in [0,1] (unclamped first)
659
+ frac = (float(block_size) - float(bs_min)) / max(1.0, float(bs_max - bs_min))
660
+ # upper bound decreases linearly from 1.0 (block_size = bs_min) to 1.0 - w (block_size = bs_max)
661
+ u_max = 1.0 - w * frac
662
+ # clamp to [0.6, 1.0] to handle bs outside [bs_min, bs_max]
663
+ u_max = max(0.6, min(1.0, u_max))
664
+ u_min = u_max - w # ensures width = w
665
+
666
+ # sample t ~ Uniform(u_min, u_max)
667
+ t = u_min + (u_max - u_min) * torch.rand(b, device=device, generator=generator)
668
+ else:
669
+ t = torch.rand(b, device=device, generator=generator)
670
+
671
+ p_mask = (1 - eps) * t + eps # shape: (b,)
672
+ p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
673
+
674
+ masked_indices = torch.rand((b, l), device=device, generator=generator) < p_mask
675
+
676
+ if loss_mask is not None:
677
+ masked_indices[loss_mask == 0] = 0
678
+
679
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
680
+
681
+ return noisy_batch, masked_indices, p_mask
682
+
683
+
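The adaptive mask-rate window in `forward_process` is easier to follow with numbers. The sketch below repeats the same arithmetic in plain Python, using the fallback defaults visible in the code (`t_bs_min=16`, `t_bs_max=128`, `t_window_width=0.6`); `mask_ratio_window` is a hypothetical name used only here.

```python
def mask_ratio_window(block_size, bs_min=16, bs_max=128, w=0.6):
    """Return (u_min, u_max), the range from which the masking time t is drawn."""
    frac = (float(block_size) - float(bs_min)) / max(1.0, float(bs_max - bs_min))
    u_max = 1.0 - w * frac
    u_max = max(0.6, min(1.0, u_max))  # the floor of 0.6 is hard-coded in the source
    u_min = u_max - w
    return u_min, u_max


small = mask_ratio_window(16)    # smallest block size: window sits at the top of [0, 1]
middle = mask_ratio_window(72)   # halfway between bs_min and bs_max
large = mask_ratio_window(128)   # largest block size: clamp keeps u_max at 0.6
print(small, middle, large)
```

With these defaults the unclamped upper bound would reach 0.4 at `bs_max`, so the clamp at 0.6 is active for the larger block sizes and keeps `u_min` from going negative.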
684
+ def forward_process_exp(
685
+ self,
686
+ input_ids: torch.Tensor,
687
+ eps: float = 1e-3,
688
+ block_size: int | None = None,
689
+ half_life_ratio: float = 0.25, # λ = ln 2 / (half_life_ratio·L)
690
+ loss_mask: Optional[torch.Tensor] = None,
691
+ ):
692
+ """
693
+ Two-stage corruption with optional per-block sampling.
694
+ • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
695
+ • Stage 2: sample exactly k positions with weights
696
+ w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
697
+ uniform when m→1).
698
+ If `block_size` is given, the procedure is run *independently*
699
+ inside each contiguous block of that length (last block may be shorter).
700
+ When block_size is provided, m is sampled per-block and p_mask is per-block.
701
+ Args
702
+ ----
703
+ input_ids : (B, L) LongTensor
704
+ eps : minimum corruption ratio
705
+ block_size: if not None, operate block-wise with per-block m sampling
706
+ half_life_ratio : controls steepness when m→0
707
+ """
708
+ B, L = input_ids.shape
709
+ device = input_ids.device
710
+ dtype = torch.float32
711
+
712
+ masked_indices = torch.zeros((B, L), dtype=torch.bool, device=device)
713
+ p_mask = torch.zeros((B, L), dtype=dtype, device=device)
714
+
715
+ # ---------- Stage 1 & 2: whole-sentence or block-wise -------------------
716
+ for b in range(B):
717
+ if block_size is None:
718
+ # ---------- Per-sequence sampling (original behavior) ----------
719
+ m = eps + (1.0 - eps) * torch.rand(1, device=device).item() # scalar
720
+ k_tot = int(round(m * L))
721
+ k_tot = max(1, min(k_tot, L)) # clamp to [1, L]
722
+
723
+ # Fill p_mask for this sequence
724
+ p_mask[b, :] = m
725
+
726
+ slope = 1.0 - m # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
727
+
728
+ # ------- single pool over the whole sentence -------------
729
+ lam_base = math.log(2.0) / (half_life_ratio * L) # base decay rate (λ when slope=1)
730
+
731
+ pos = torch.arange(L, device=device, dtype=dtype)
732
+ log_w = (lam_base * slope * pos).clone()
733
+
734
+ masked_indices[b] = gumbel_topk(log_w, k_tot)
735
+
736
+ else:
737
+ # ---------- Per-block sampling ----------
738
+ num_blocks = math.ceil(L / block_size)
739
+ lam_base = math.log(2.0) / (half_life_ratio * block_size) # base decay rate (λ when slope=1)
740
+
741
+ for blk in range(num_blocks):
742
+ start = blk * block_size
743
+ end = min((blk + 1) * block_size, L)
744
+ blk_len = end - start
745
+
746
+ # Sample m per block
747
+ m_blk = eps + (1.0 - eps) * torch.rand(1, device=device).item()
748
+
749
+ # Fill p_mask for this block
750
+ p_mask[b, start:end] = m_blk
751
+
752
+ # per-block budget
753
+ k_blk = int(round(m_blk * blk_len))
754
+ k_blk = max(0, min(k_blk, blk_len))
755
+ if k_blk == 0:
756
+ continue
757
+
758
+ slope = 1.0 - m_blk # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
759
+
760
+ pos = torch.arange(blk_len, device=device, dtype=dtype)
761
+ log_w = lam_base * slope * pos
762
+
763
+ blk_mask = gumbel_topk(log_w, k_blk)
764
+ masked_indices[b, start:end] = blk_mask
765
+
766
+ if loss_mask is not None:
767
+ masked_indices[loss_mask == 0] = 0
768
+
769
+ noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
770
+ return noisy_batch, masked_indices, p_mask
771
+
772
+
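The position weights in `forward_process_exp` follow `w_i = exp(lam_base * (1 - m) * i)` with `lam_base = ln 2 / (half_life_ratio * length)`. The sketch below evaluates them for the two limiting ratios; `position_weights` is a hypothetical helper, and `m=0.0` is used as the limiting case even though the sampled ratio is never below `eps`.

```python
import math


def position_weights(length, m, half_life_ratio=0.25):
    """Unnormalised sampling weights over positions 0..length-1 for masking ratio m."""
    lam_base = math.log(2.0) / (half_life_ratio * length)
    slope = 1.0 - m  # 0 gives uniform weights, 1 gives the steepest late-heavy profile
    return [math.exp(lam_base * slope * i) for i in range(length)]


late_heavy = position_weights(16, m=0.0)
uniform = position_weights(16, m=1.0)
print(late_heavy[0], late_heavy[4], late_heavy[15])
print(uniform[:4])
```

For a length of 16 and `half_life_ratio=0.25`, the weight doubles every 4 positions at `m=0`, so late tokens are far more likely to be masked when few tokens are masked overall.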
773
+ def forward(
774
+ self,
775
+ input_ids: torch.LongTensor,
776
+ attention_mask: Optional[torch.Tensor] = None,
777
+ position_ids: Optional[torch.LongTensor] = None,
778
+ labels: Optional[torch.LongTensor] = None,
779
+ split_len: Optional[int] = None,
780
+ past_key_values: Optional[Cache] = None,
781
+ block_size: Optional[int] = None,
782
+ block_diff_ppl: bool = False,
783
+ eps: float = 1e-3,
784
+ is_teacher: bool = False,
785
+ masked_indices: Optional[torch.Tensor] = None,
786
+ p_mask: Optional[torch.Tensor] = None,
787
+ teacher_logits: Optional[torch.Tensor] = None,
788
+ masked_indices_teacher: Optional[torch.Tensor] = None,
789
+ loss_mask: Optional[torch.Tensor] = None,
790
+ ce_loss_weight: float = 1.0,
791
+ output_last_hidden_states_only: bool = False,
792
+ skip_loss: bool = False,
793
+ pixel_values: Optional[torch.FloatTensor] = None,
794
+ image_sizes: Optional[torch.Tensor] = None,
795
+ **kwargs,
796
+ ) -> CausalLMOutputWithPast:
797
+
798
+ batch_size, seq_len = input_ids.shape
799
+
800
+ if self.config.dlm_paradigm == 'bidirectional' or self.config.dlm_paradigm == 'autoregressive':
801
+ if labels is not None and torch.rand(1) < self.config.random_length_prob:
802
+ random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
803
+ input_ids = input_ids[:, :random_length]
804
+ labels = labels[:, :random_length]
805
+
806
+ if attention_mask is not None:
807
+ attention_mask = attention_mask[:, :random_length]
808
+ if position_ids is not None:
809
+ position_ids = position_ids[:, :random_length]
810
+ if loss_mask is not None:
811
+ loss_mask = loss_mask[:, :random_length]
812
+
813
+ elif self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
814
+ if labels is not None and block_size is None:
815
+ if torch.rand(1) < self.config.random_length_prob:
816
+ block_size = torch.randint(1, 8, (1,)).item() * 4 ## one of {4, 8, ..., 28}; randint's upper bound is exclusive
817
+ else:
818
+ block_size = self.config.block_size
819
+
820
+ else:
821
+ raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
822
+
823
+ if labels is not None and self.config.dlm_paradigm != 'autoregressive':
824
+ if masked_indices is not None:
825
+ # assert p_mask is not None
826
+
827
+ if loss_mask is not None:
828
+ masked_indices[loss_mask == 0] = 0
829
+
830
+ noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
831
+
832
+ else:
833
+ if self.config.complementary_mask:
834
+ loss_mask = (labels != -100)
835
+ noisy_inputs, masked_indices, p_mask = self.forward_process_complementary(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
836
+ else:
837
+ if self.config.tok_mask_half_life_ratio is not None:
838
+ noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
839
+ else:
840
+ noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
841
+
842
+ else:
843
+ noisy_inputs = input_ids
844
+ masked_indices = None
845
+ p_mask = None
846
+
847
+ if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
848
+ for layer in self.encoder.layers:
849
+ if hasattr(layer.self_attn, 'set_attention_mode'):
850
+ layer.self_attn.set_attention_mode(self.config.dlm_paradigm, block_size=block_size)
851
+
852
+ input_ids_len = noisy_inputs.shape[1]
853
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
854
+ if position_ids is None:
855
+ position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
856
+ if self.config.complementary_mask:
857
+ noisy_inputs = torch.cat([noisy_inputs, torch.cat([input_ids, input_ids], dim=0)], dim=1)
858
+ else:
859
+ noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
860
+
861
+ if block_diff_ppl:
862
+ if position_ids is None:
863
+ position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
864
+
865
+ # ── Vision: replace IMG pad embeddings with image features ────────
866
+ if pixel_values is not None and image_sizes is not None:
867
+ inputs_embeds = self._embed_with_vision(noisy_inputs, pixel_values, image_sizes)
868
+ enc_out = self.encoder(
869
+ past_key_values=past_key_values,
870
+ inputs_embeds=inputs_embeds,
871
+ attention_mask=attention_mask,
872
+ position_ids=position_ids,
873
+ is_training=(labels is not None) or (block_diff_ppl),
874
+ **kwargs,
875
+ )
876
+ elif self.training and pixel_values is None and not self._is_vision_frozen():
877
+ vt = self.encoder.vision_tower
878
+ _p = vt.patch_size
879
+ _merge = getattr(self.config, "spatial_merge_size", 2)
880
+ _side = _p * _merge
881
+ _c = getattr(vt.config, "num_channels", 3)
882
+ _dtype = next(vt.parameters()).dtype
883
+ dummy_pixel = torch.zeros(
884
+ 1, _c, _side, _side,
885
+ dtype=_dtype, device=noisy_inputs.device,
886
+ )
887
+ dummy_image_sizes = torch.tensor(
888
+ [(int(_side), int(_side))],
889
+ dtype=torch.long, device=noisy_inputs.device,
890
+ )
891
+ dummy_features = self.get_image_features(dummy_pixel, dummy_image_sizes)
892
+ inputs_embeds = self.encoder.embed_tokens(noisy_inputs)
893
+ inputs_embeds = inputs_embeds + dummy_features.sum() * 0
894
+ enc_out = self.encoder(
895
+ past_key_values=past_key_values,
896
+ inputs_embeds=inputs_embeds,
897
+ attention_mask=attention_mask,
898
+ position_ids=position_ids,
899
+ is_training=(labels is not None) or (block_diff_ppl),
900
+ **kwargs,
901
+ )
902
+ else:
903
+ enc_out = self.encoder(
904
+ past_key_values=past_key_values,
905
+ input_ids=noisy_inputs,
906
+ attention_mask=attention_mask,
907
+ position_ids=position_ids,
908
+ is_training=(labels is not None) or (block_diff_ppl),
909
+ **kwargs,
910
+ )
911
+
912
+ if output_last_hidden_states_only:
913
+ return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
914
+
915
+ logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, seq_len, vocab)
916
+ causal_logits = None
917
+
918
+ if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
919
+ if self.config.dlm_paradigm == 'sbd_block_diff':
920
+ causal_logits = logits[:, input_ids_len:]
921
+ else:
922
+ causal_logits = None
923
+
924
+ logits = logits[:, :input_ids_len]
925
+
926
+ loss = None
927
+ if getattr(self.config, 'complementary_mask', False) and self.config.dlm_paradigm == 'sbd_block_diff':
928
+ _raw_nib = kwargs.get('num_items_in_batch', None)
929
+ kwargs = {**kwargs, 'num_items_in_batch': 2 * kwargs.get('num_items_in_batch', 1)}
930
+ if self.training and (not hasattr(self, '_nib_logged') or not self._nib_logged):
931
+ import torch.distributed as dist
932
+ _rank = dist.get_rank() if dist.is_initialized() else 0
933
+ if _rank == 0:
934
+ print(f"[DEBUG-NIB] raw num_items_in_batch from Trainer: {_raw_nib}, "
935
+ f"after 2x: {kwargs['num_items_in_batch']}, "
936
+ f"labels non-(-100): {(labels != -100).sum().item() if labels is not None else 'N/A'}, "
937
+ f"batch_size={input_ids.shape[0]}, seq_len={input_ids.shape[1]}", flush=True)
938
+ self._nib_logged = True
939
+ if labels is not None and not skip_loss:
940
+ if self.config.dlm_paradigm == 'autoregressive':
941
+ shift_logits = logits[..., :-1, :].contiguous()
942
+ shift_labels = labels[..., 1:].contiguous()
943
+
944
+ if loss_mask is None:
945
+ loss_fct = CrossEntropyLoss()
946
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
947
+ shift_labels = shift_labels.view(-1)
948
+ loss = loss_fct(shift_logits, shift_labels)
949
+
950
+ else:
951
+ loss_mask = loss_mask[..., 1:].contiguous()
952
+
953
+ loss_fct = CrossEntropyLoss(reduction='none')
954
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
955
+ shift_labels = shift_labels.view(-1)
956
+ shift_labels = shift_labels.to(shift_logits.device)
957
+
958
+ token_losses = loss_fct(shift_logits, shift_labels)
959
+
960
+ flat_loss_mask = loss_mask.reshape(-1)
961
+ loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
962
+
963
+ else:
964
+ # Handle DREAM vs LLADA style losses
965
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
966
+ logits = logits[..., :-1, :].contiguous()
967
+ labels = labels[..., 1:].contiguous()
968
+ masked_indices = masked_indices[:, 1:]
969
+ if p_mask is not None:
970
+ p_mask = p_mask[:, 1:]
971
+
972
+ if self.config.ada_perm_ratio_per_block is not None:
973
+ # Only compute loss for the top ada_perm_ratio_per_block tokens by confidence within each block
974
+ block_size = self.config.block_size
975
+ batch_size, seq_len = masked_indices.shape
976
+ num_blocks = seq_len // block_size
977
+
978
+ # Get the max logit (confidence) for each position
979
+ confidence = logits.max(dim=-1).values.detach() # (batch_size, seq_len)
980
+
981
+ # Create a mask for tokens to include in loss
982
+ selected_mask = torch.zeros_like(masked_indices, dtype=torch.bool)
983
+
984
+ for blk in range(num_blocks):
985
+ start = blk * block_size
986
+ end = min((blk + 1) * block_size, seq_len)
987
+
988
+ # Get masked indices within this block
989
+ block_masked = masked_indices[:, start:end] # (batch_size, block_len)
990
+ block_confidence = confidence[:, start:end] # (batch_size, block_len)
991
+
992
+ for b in range(batch_size):
993
+ # Get positions that are masked in this block for this batch
994
+ masked_positions = torch.where(block_masked[b])[0]
995
+ num_masked = len(masked_positions)
996
+
997
+ if num_masked > 0:
998
+ # Number of tokens to keep (top by confidence)
999
+ k = min(max(1, int(block_size * self.config.ada_perm_ratio_per_block)), num_masked)
1000
+
1001
+ # Get confidence values for masked positions
1002
+ masked_confidence = block_confidence[b, masked_positions]
1003
+
1004
+ # Get indices of top-k confident tokens
1005
+ _, topk_indices = torch.topk(masked_confidence, k)
1006
+ selected_positions = masked_positions[topk_indices]
1007
+
1008
+ # Mark these positions in the selected mask
1009
+ selected_mask[b, start + selected_positions] = True
1010
+
1011
+ # Calculate loss only for selected positions
1012
+ token_loss = torch.nn.functional.cross_entropy(
1013
+ logits[selected_mask],
1014
+ labels[selected_mask],
1015
+ reduction='none'
1016
+ ) / p_mask[selected_mask]
1017
+
1018
+ num_mask_tokens = selected_mask.sum()
1019
+
1020
+ elif getattr(self.config, 'complementary_mask', False):
1021
+ token_loss = self.mdm_loss_function(
1022
+ logits=logits[masked_indices],
1023
+ labels=torch.cat([labels, labels], dim=0)[masked_indices],
1024
+ vocab_size=self.config.vocab_size,
1025
+ **kwargs
1026
+ )
1027
+ num_mask_tokens = masked_indices.sum()
1028
+
1029
+ else:
1030
+ # Calculate token-wise cross entropy loss for masked positions in B
1031
+ token_loss = torch.nn.functional.cross_entropy(
1032
+ logits[masked_indices],
1033
+ labels[masked_indices],
1034
+ reduction='none'
1035
+ ) / p_mask[masked_indices]
1036
+
1037
+ num_mask_tokens = masked_indices.sum()
1038
+
1039
+ if self.config.global_loss_avg:
1040
+ loss = token_loss.sum()
1041
+ else:
1042
+ loss = token_loss.sum() / num_mask_tokens
1043
+
1044
+ if self.config.ada_dlm_loss_ratio is not None:
1045
+ assert self.current_iter_ratio is not None
1046
+ assert self.config.dlm_loss_weight is not None
1047
+
1048
+ dlm_loss_weight = min(self.config.dlm_loss_weight, self.current_iter_ratio / self.config.ada_dlm_loss_ratio * self.config.dlm_loss_weight)
1049
+ loss = dlm_loss_weight * loss
1050
+ elif self.config.dlm_loss_weight is not None:
1051
+ loss = self.config.dlm_loss_weight * loss
1052
+
1053
+ if self.config.dlm_paradigm == 'sbd_block_diff':
1054
+
1055
+ if getattr(self.config, 'complementary_mask', False):
1056
+ ar_loss = self.causal_loss_function(
1057
+ logits=causal_logits[:logits.shape[0] // 2, :],
1058
+ labels=labels,
1059
+ vocab_size=self.config.vocab_size,
1060
+ **kwargs
1061
+ )
1062
+
1063
+ _diff_val = loss.detach().item()
1064
+ _ar_val = ar_loss.detach().item()
1065
+ if not hasattr(self, '_loss_accum_count'):
1066
+ self._loss_diff_accum = 0.0
1067
+ self._loss_ar_accum = 0.0
1068
+ self._loss_accum_count = 0
1069
+ self._loss_diff_accum += _diff_val
1070
+ self._loss_ar_accum += _ar_val
1071
+ self._loss_accum_count += 1
1072
+ self.loss_diffusion = self._loss_diff_accum
1073
+ self.loss_ar = self._loss_ar_accum
1074
+
1075
+ loss = loss + ar_loss
1076
+ else:
1077
+ causal_logits = causal_logits[..., :-1, :].contiguous()
1078
+ causal_logits = causal_logits.view(-1, causal_logits.size(-1))
1079
+
1080
+ if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
1081
+ causal_labels = labels.view(-1)
1082
+ else:
1083
+ causal_labels = labels[..., 1:].contiguous().view(-1)
1084
+
1085
+ if self.config.global_loss_avg:
1086
+ loss_fct = CrossEntropyLoss(reduction='sum')
1087
+ ar_loss = loss_fct(causal_logits, causal_labels)
1088
+
1089
+ self.loss_diffusion = loss.detach().item() / num_mask_tokens
1090
+ self.loss_ar = ar_loss.detach().item() / seq_len
1091
+
1092
+ loss = loss + self.config.ar_loss_weight * ar_loss
1093
+
1094
+ else:
1095
+ loss_fct = CrossEntropyLoss()
1096
+ ar_loss = loss_fct(causal_logits, causal_labels)
1097
+
1098
+ self.loss_diffusion = loss.detach().item()
1099
+ self.loss_ar = ar_loss.detach().item()
1100
+
1101
+ loss = loss + self.config.ar_loss_weight * ar_loss
1102
+
1103
+ # if self.config.global_loss_avg:
1104
+ # if self.config.dlm_paradigm == 'sbd_block_diff':
1105
+ # loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
1106
+ # else:
1107
+ # loss = (loss, num_mask_tokens)
1108
+
1109
+ return NemotronLabsDiffusionVLMOutputWithPast(
1110
+ loss=loss if not is_teacher else logits,
1111
+ logits=logits,
1112
+ causal_logits=causal_logits,
1113
+ past_key_values=enc_out.past_key_values,
1114
+ hidden_states=None,
1115
+ attentions=None,
1116
+ )
1117
+
1118
+
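In the default (non-complementary) diffusion branch of `forward`, each masked token's cross-entropy is divided by its masking probability and the sum is normalised by the number of masked tokens. A minimal sketch of that computation on toy tensors, with `masked_diffusion_loss` as a hypothetical name and without the `global_loss_avg` or loss-weight options:

```python
import math

import torch
import torch.nn.functional as F


def masked_diffusion_loss(logits, labels, masked, p_mask):
    """Cross-entropy over masked positions only, reweighted by 1 / p_mask."""
    token_loss = F.cross_entropy(logits[masked], labels[masked], reduction="none") / p_mask[masked]
    return token_loss.sum() / masked.sum()


vocab = 10
logits = torch.zeros(2, 4, vocab)  # uniform predictions, so CE = ln(vocab) per token
labels = torch.zeros(2, 4, dtype=torch.long)
masked = torch.tensor([[True, False, True, False],
                       [False, False, True, False]])
p_mask = torch.full((2, 4), 0.5)

loss = masked_diffusion_loss(logits, labels, masked, p_mask)
expected = math.log(vocab) / 0.5
print(loss.item(), expected)
```

With uniform logits every masked token contributes `ln(10)`, and dividing by `p_mask = 0.5` doubles it, which is what the printed values should show.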
1119
+ def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold,
1120
+ causal_context=True, temperature=0, pixel_values=None, image_sizes=None, eos_token_id=None):
1121
+ out_ids, nfe = generate_with_prefix_cache_block_diff(
1122
+ model=self,
1123
+ prompt=prompt_ids,
1124
+ gen_length=max_new_tokens,
1125
+ steps=steps,
1126
+ block_length=block_length,
1127
+ remasking="low_confidence",
1128
+ temperature=temperature,
1129
+ mask_id=self.mask_token_id,
1130
+ threshold=threshold,
1131
+ shift_logits=shift_logits,
1132
+ neg_entropy=False,
1133
+ causal_context=causal_context,
1134
+ pixel_values=pixel_values,
1135
+ image_sizes=image_sizes,
1136
+ eos_token_id=eos_token_id,
1137
+ )
1138
+
1139
+ return out_ids, nfe
1140
+
1141
+ @torch.no_grad()
1142
+ def sbd_inference_diffusion_quadratic(
1143
+ self,
1144
+ clean_input_ids: Optional[torch.Tensor],
1145
+ draft_input_ids: torch.Tensor,
1146
+ block_length: int,
1147
+ draft_only: bool = False,
1148
+ past_key_values: Optional[Cache] = None,
1149
+ use_cache: bool = False,
1150
+ pixel_values: Optional[torch.FloatTensor] = None,
1151
+ image_sizes: Optional[torch.Tensor] = None,
1152
+ ):
1153
+ enc_config = self.encoder.config
1154
+ enc_config.use_sbd_objective = True
1155
+ enc_config.block_length = block_length
1156
+
1157
+ if draft_only:
1158
+ assert clean_input_ids is not None
1159
+
1160
+ if use_cache and past_key_values is None:
1161
+ past_key_values = DynamicCache()
1162
+
1163
+ enc_config.self_spec_inference_mode = "default"
1164
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
1165
+ if pixel_values is not None and image_sizes is not None:
1166
+ inputs_embeds = self._embed_with_vision(input_ids, pixel_values, image_sizes)
1167
+ outputs = self.encoder(
1168
+ inputs_embeds=inputs_embeds,
1169
+ position_ids=None,
1170
+ past_key_values=past_key_values,
1171
+ use_cache=use_cache,
1172
+ is_training=False,
1173
+ )
1174
+ else:
1175
+ outputs = self.encoder(
1176
+ input_ids=input_ids,
1177
+ position_ids=None,
1178
+ past_key_values=past_key_values,
1179
+ use_cache=use_cache,
1180
+ is_training=False,
1181
+ )
1182
+
1183
+ hidden_states = outputs.last_hidden_state
1184
+ logits = self.diffusion_head(hidden_states)
1185
+
1186
+ past_key_values = getattr(outputs, "past_key_values", None)
1187
+ if use_cache and past_key_values is not None:
1188
+ _crop_dynamic_cache(past_key_values, clean_input_ids.shape[1])
1189
+
1190
+ return logits, past_key_values
1191
+ else:
1192
+ enc_config.self_spec_inference_mode = "quadratic"
1193
+
1194
+ draft_len = block_length * (block_length + 1)
1195
+ draft_input_ids = torch.cat(
1196
+ [
1197
+ draft_input_ids.view(-1, block_length, 1),
1198
+ torch.full(
1199
+ (draft_input_ids.shape[0], block_length, block_length),
1200
+ fill_value=self.config.mask_token_id,
1201
+ device=draft_input_ids.device,
1202
+ ),
1203
+ ],
1204
+ dim=-1,
1205
+ ).view(-1, draft_len)
1206
+
1207
+ if use_cache:
1208
+ assert past_key_values is not None, (
1209
+ "Past key values should be provided when using cache, e.g. run draft_only=True first."
1210
+ )
1211
+ assert clean_input_ids is None, (
1212
+ "Clean input ids should already be in cache, thus none should be provided."
1213
+ )
1214
+ clean_len = past_key_values.get_seq_length()
1215
+ input_ids = draft_input_ids
1216
+ else:
1217
+ clean_len = clean_input_ids.shape[1]
1218
+ input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
1219
+
1220
+ per_block_position_ids = torch.arange(
1221
+ clean_len, clean_len + block_length + 1, device=draft_input_ids.device
1222
+ )[None,].repeat(block_length, 1)
1223
+ per_block_position_ids += torch.arange(block_length, device=draft_input_ids.device).view(-1, 1)
1224
+
1225
+ if use_cache:
1226
+ position_ids = per_block_position_ids.view(-1)[None,]
1227
+ else:
1228
+ clean_position_ids = torch.arange(clean_len, device=draft_input_ids.device)
1229
+ position_ids = torch.cat([clean_position_ids, per_block_position_ids.view(-1)], dim=-1)[None,]
1230
+
1231
+ if pixel_values is not None and image_sizes is not None and not use_cache:
1232
+ inputs_embeds = self._embed_with_vision(input_ids, pixel_values, image_sizes)
1233
+ outputs = self.encoder(
1234
+ inputs_embeds=inputs_embeds,
1235
+ position_ids=position_ids,
1236
+ past_key_values=past_key_values,
1237
+ use_cache=use_cache,
1238
+ is_training=False,
1239
+ )
1240
+ else:
1241
+ outputs = self.encoder(
1242
+ input_ids=input_ids,
1243
+ position_ids=position_ids,
1244
+ past_key_values=past_key_values,
1245
+ use_cache=use_cache,
1246
+ is_training=False,
1247
+ )
1248
+
1249
+ hidden_states = outputs.last_hidden_state
1250
+ logits = self.diffusion_head(hidden_states)
1251
+ past_key_values = getattr(outputs, "past_key_values", None)
1252
+
1253
+ if use_cache and past_key_values is not None:
1254
+ _extract_draft_kv_cache(past_key_values, clean_len, block_length)
1255
+
1256
+ return logits, past_key_values
1257
+
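Review note on `sbd_inference_diffusion_quadratic`: the input layout of the `quadratic` branch is easy to misread, so here is a minimal pure-Python sketch of it. The helper names `expand_draft` and `quadratic_position_ids` are hypothetical and do not exist in this commit; they only mirror the `torch.cat(...).view(-1, draft_len)` and `per_block_position_ids` computations above for a single sequence.

```python
from typing import List


def expand_draft(draft: List[int], mask_id: int) -> List[int]:
    """Follow every draft token with `block_length` mask tokens.

    Hypothetical helper: mirrors the cat/view that builds a sequence of
    length block_length * (block_length + 1) in the quadratic branch.
    """
    block_length = len(draft)
    expanded: List[int] = []
    for token in draft:
        expanded.append(token)
        expanded.extend([mask_id] * block_length)
    return expanded


def quadratic_position_ids(clean_len: int, block_length: int) -> List[int]:
    """Position ids for the expanded draft.

    Row i covers positions clean_len + i .. clean_len + i + block_length,
    i.e. each row is the previous one shifted right by one position.
    """
    position_ids: List[int] = []
    for i in range(block_length):
        start = clean_len + i
        position_ids.extend(range(start, start + block_length + 1))
    return position_ids


if __name__ == "__main__":
    print(expand_draft([11, 12], mask_id=0))
    print(quadratic_position_ids(clean_len=10, block_length=2))
```

Under these assumptions, row `i` holds draft token `i` at its own position followed by `block_length` masked continuation slots. That is why the forward pass costs `block_length * (block_length + 1)` tokens per verification step.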
+     @torch.no_grad()
+     def self_spec_generate(
+         self,
+         prompt_ids: torch.Tensor,
+         max_new_tokens: int = 128,
+         steps: int = 128,
+         block_length: int = 16,
+         ar_mix_weight: Optional[float] = None,
+         temperature: float = 0.0,
+         mask_token_id: Optional[int] = None,
+         eos_token_id: Optional[int] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+     ):
+         self.config.use_sbd_objective = True
+         self.config.dlm_paradigm = "sbd"
+
+         if prompt_ids.shape[0] != 1:
+             raise ValueError("Self speculation quadratic decoding currently requires batch_size == 1")
+
+         token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
+         if eos_token_id is None:
+             eos_token_id = getattr(self.config, "eos_token_id", None)
+
+         x = torch.full(
+             (1, prompt_ids.shape[1] + max_new_tokens + block_length * 2),
+             token_mask_id,
+             dtype=torch.long,
+             device=prompt_ids.device,
+         )
+         x[:, : prompt_ids.shape[1]] = prompt_ids.clone()
+
+         if max_new_tokens % block_length != 0:
+             raise ValueError("max_new_tokens must be divisible by block_length")
+         num_blocks = max_new_tokens // block_length
+         if steps % num_blocks != 0:
+             raise ValueError("steps must be divisible by (max_new_tokens // block_length)")
+
+         prompt_len = prompt_ids.shape[1]
+         nfe = 0
+         nfe += 1
+         logits, past_key_values = self.sbd_inference_diffusion_quadratic(
+             clean_input_ids=x[:, :prompt_len],
+             draft_input_ids=x[:, prompt_len : prompt_len + block_length],
+             block_length=block_length,
+             draft_only=True,
+             use_cache=True,
+             pixel_values=pixel_values,
+             image_sizes=image_sizes,
+         )
+
+         logits_proposal = logits[:, prompt_len - 1 : prompt_len + block_length]
+         logits_proposal[:, 1] = logits_proposal[:, 0]
+         logits_proposal = logits_proposal[:, 1:]
+         x0_proposal = torch.argmax(logits_proposal, dim=-1)
+         x[:, prompt_len : prompt_len + block_length] = x0_proposal
+
+         total_accept_token = 0
+         while True:
+             nfe += 1
+             block_start = prompt_len + total_accept_token
+             block_end = block_start + block_length
+             draft_input_ids = x[:, block_start:block_end]
+
+             logits, past_key_values = self.sbd_inference_diffusion_quadratic(
+                 clean_input_ids=None,
+                 draft_input_ids=draft_input_ids,
+                 block_length=block_length,
+                 draft_only=False,
+                 past_key_values=past_key_values,
+                 use_cache=True,
+                 pixel_values=pixel_values,
+                 image_sizes=image_sizes,
+             )
+
+             useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
+             if ar_mix_weight is None:
+                 useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
+             else:
+                 if not (0.0 <= ar_mix_weight <= 1.0):
+                     raise ValueError("ar_mix_weight must be between 0 and 1")
+                 mix_logits = useful_token_logits[:, :, 0] * ar_mix_weight + useful_token_logits[:, :, 1] * (1 - ar_mix_weight)
+                 useful_token_logits[:, :, 0] = mix_logits
+                 useful_token_logits[:, :, 1] = mix_logits
+
+             if temperature > 0:
+                 useful_token_logits = useful_token_logits / temperature
+
+             useful_token_pred = torch.argmax(useful_token_logits, dim=-1)
+             new_draft_input_ids = useful_token_pred[:, 0, 1:]
+             accept_cnt = 1
+
+             while accept_cnt < block_length:
+                 if useful_token_pred[:, accept_cnt - 1, 0].item() != draft_input_ids[:, accept_cnt].item():
+                     break
+                 new_draft_input_ids = useful_token_pred[:, accept_cnt, 1:]
+                 accept_cnt += 1
+
+             x[:, block_start : block_start + accept_cnt] = draft_input_ids[:, :accept_cnt]
+
+             # EoS early stopping
+             if eos_token_id is not None:
+                 accepted = x[0, block_start : block_start + accept_cnt]
+                 eos_positions = (accepted == eos_token_id).nonzero(as_tuple=True)[0]
+                 if len(eos_positions) > 0:
+                     first_eos_rel = eos_positions[0].item()
+                     total_accept_token += first_eos_rel + 1
+                     output_end = prompt_len + total_accept_token
+                     return x[:, :output_end], nfe
+
+             x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
+             _crop_dynamic_cache(past_key_values, block_start + accept_cnt)
+             total_accept_token += accept_cnt
+
+             if total_accept_token >= max_new_tokens:
+                 break
+
+         return x[:, : -(block_length * 2)], nfe
+
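Review note on `self_spec_generate`: the accept/reject step inside the `while True` loop can be read in isolation from the tensor bookkeeping. Below is a minimal sketch on plain Python lists, assuming greedy (argmax) predictions and batch size 1. `verify_draft` and its argument names are hypothetical; `next_token_pred[i]` stands in for `useful_token_pred[:, i, 0]` and `redrafts[i]` for `useful_token_pred[:, i, 1:]`.

```python
from typing import List, Sequence, Tuple


def verify_draft(
    draft: Sequence[int],
    next_token_pred: Sequence[int],
    redrafts: Sequence[Sequence[int]],
) -> Tuple[int, List[int]]:
    """Return (accept_cnt, next_draft) for one verification step.

    draft[i]            -- i-th token of the current draft block
    next_token_pred[i]  -- verifier's prediction for the token after draft[i]
    redrafts[i]         -- fresh block proposed as continuation of draft[:i + 1]
    """
    # The first draft token is always kept, as in the loop above
    # (accept_cnt starts at 1).
    accept_cnt = 1
    next_draft = list(redrafts[0])
    while accept_cnt < len(draft):
        # Stop at the first draft token the verifier disagrees with.
        if next_token_pred[accept_cnt - 1] != draft[accept_cnt]:
            break
        next_draft = list(redrafts[accept_cnt])
        accept_cnt += 1
    return accept_cnt, next_draft


if __name__ == "__main__":
    draft = [5, 7, 9, 2]
    next_token_pred = [7, 9, 4, 1]
    redrafts = [[7, 1, 1, 1], [9, 2, 2, 2], [4, 3, 3, 3], [1, 4, 4, 4]]
    print(verify_draft(draft, next_token_pred, redrafts))
```

In this toy run the verifier agrees with the draft up to the third token and rejects the fourth, so three tokens are committed and the next draft is the block proposed after the last accepted token. Every iteration commits at least one token, so the loop always makes progress toward `max_new_tokens`.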
+
+ __all__ = ["NemotronLabsDiffusionVLMModel", "NemotronLabsDiffusionVLMFlexAttention"]
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "additional_special_tokens": [
+     "|<MASK>|"
+   ],
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenization_nemotron_labs_diffusion_vlm.py ADDED
@@ -0,0 +1,46 @@
+ """
+ Custom tokenizer for Nemotron-Diffusion-Exp-Ministral-8B-Instruct (final-template).
+
+ Extends PreTrainedTokenizerFast with a `process_messages` method that
+ handles image token expansion and pixel value preprocessing, analogous
+ to MistralCommonBackend.apply_chat_template(return_dict=True).
+
+ Usage:
+     tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
+     result = tokenizer.process_messages(messages)
+     # result["input_ids"] – (1, seq_len) with expanded image tokens
+     # result["pixel_values"] – (N, 3, H, W) if images present
+     # result["image_sizes"] – list of (H, W) tuples
+ """
+
+ from typing import Any, Dict, List
+
+ from transformers import PreTrainedTokenizerFast
+
+ from .image_processing import process_messages as _process_messages
+
+
+ class NemotronLabsDiffusionVLMTokenizerFast(PreTrainedTokenizerFast):
+     """PreTrainedTokenizerFast + image-aware process_messages()."""
+
+     def process_messages(
+         self,
+         messages: List[Dict[str, Any]],
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """
+         Process chat messages with optional images.
+
+         Renders the chat template, expands image placeholders based on
+         actual image dimensions, preprocesses pixel values, and tokenizes.
+
+         Args:
+             messages: OpenAI-style list of message dicts.
+             **kwargs: forwarded to image_processing.process_messages
+                 (patch_size, spatial_merge_size, max_image_size,
+                 return_tensors, enable_thinking).
+
+         Returns:
+             dict with input_ids, and optionally pixel_values + image_sizes.
+         """
+         return _process_messages(self, messages, **kwargs)
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e04613060199ab156b35bf2334e381748ae41311e8785efb330bc66e16670d8
+ size 17077689
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff