Image-Text-to-Text
Transformers
Safetensors
PyTorch
nemotron_labs_diffusion_vlm
feature-extraction
nvidia
multimodal
vlm
diffusion-language-model
conversational
custom_code
Instructions to use nvidia/Nemotron-Labs-Diffusion-VLM-8B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nvidia/Nemotron-Labs-Diffusion-VLM-8B with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("image-text-to-text", model="nvidia/Nemotron-Labs-Diffusion-VLM-8B", trust_remote_code=True) messages = [ { "role": "user", "content": [ {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"}, {"type": "text", "text": "What animal is on the candy?"} ] }, ] pipe(text=messages)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("nvidia/Nemotron-Labs-Diffusion-VLM-8B", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use nvidia/Nemotron-Labs-Diffusion-VLM-8B with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "nvidia/Nemotron-Labs-Diffusion-VLM-8B" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "nvidia/Nemotron-Labs-Diffusion-VLM-8B", "messages": [ { "role": "user", "content": [ { "type": "text", "text": "Describe this image in one sentence." }, { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } } ] } ] }'Use Docker
docker model run hf.co/nvidia/Nemotron-Labs-Diffusion-VLM-8B
- SGLang
How to use nvidia/Nemotron-Labs-Diffusion-VLM-8B with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "nvidia/Nemotron-Labs-Diffusion-VLM-8B" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "nvidia/Nemotron-Labs-Diffusion-VLM-8B", "messages": [ { "role": "user", "content": [ { "type": "text", "text": "Describe this image in one sentence." }, { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } } ] } ] }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "nvidia/Nemotron-Labs-Diffusion-VLM-8B" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "nvidia/Nemotron-Labs-Diffusion-VLM-8B", "messages": [ { "role": "user", "content": [ { "type": "text", "text": "Describe this image in one sentence." }, { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } } ] } ] }' - Docker Model Runner
How to use nvidia/Nemotron-Labs-Diffusion-VLM-8B with Docker Model Runner:
docker model run hf.co/nvidia/Nemotron-Labs-Diffusion-VLM-8B
Commit ·
c6706ba
0
Parent(s):
Initial release of Nemotron-Labs-Diffusion-VLM-8B
Browse filesCo-authored-by: pmolchanov <pmolchanov@users.noreply.huggingface.co>
- .gitattributes +41 -0
- README.md +124 -0
- assets/demo.gif +3 -0
- assets/demo.mp4 +3 -0
- assets/result_acc.png +3 -0
- assets/result_efficiency.png +3 -0
- assets/teaser.png +3 -0
- chat_template.jinja +227 -0
- chat_utils.py +272 -0
- config.json +106 -0
- configuration_nemotron_labs_diffusion_vlm.py +259 -0
- generation_config.json +11 -0
- image_processing.py +296 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +539 -0
- model_cards/bias.md +4 -0
- model_cards/explainability.md +13 -0
- model_cards/privacy.md +11 -0
- model_cards/safety.md +6 -0
- modeling_ministral.py +629 -0
- modeling_nemotron_labs_diffusion_vlm.py +1378 -0
- special_tokens_map.json +33 -0
- tokenization_nemotron_labs_diffusion_vlm.py +46 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/demo.gif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/demo.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/result_acc.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/result_efficiency.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
license: other
|
| 4 |
+
license_name: nscl-v1
|
| 5 |
+
pipeline_tag: image-text-to-text
|
| 6 |
+
tags:
|
| 7 |
+
- nvidia
|
| 8 |
+
- pytorch
|
| 9 |
+
- multimodal
|
| 10 |
+
- vlm
|
| 11 |
+
- diffusion-language-model
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Nemotron-Labs-Diffusion-VLM-8B
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
<div align="center" style="line-height: 1;">
|
| 18 |
+
<a href="https://d1qx31qr3h6wln.cloudfront.net/publications/Nemotron_Diffusion_Tech_Report_v1.pdf?VersionId=db8_EMO8B.vmU26.jr7Le9pN3MqcUDNL" target="_blank" style="margin: 2px;">
|
| 19 |
+
<img alt="Chat" src="https://img.shields.io/badge/📝Paper-Read Now!-536af5?color=76B900&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 20 |
+
</a>
|
| 21 |
+
<a href="https://huggingface.co/collections/nvidia/nemotron-labs-diffusion" target="_blank" style="margin: 2px;">
|
| 22 |
+
<img alt="Nemotron-Labs-Diffusion Model Family" src="https://img.shields.io/badge/%F0%9F%A4%97-Nemotron--Labs--Diffusion_Model_Family-76B900" style="display: inline-block; vertical-align: middle;"/>
|
| 23 |
+
</a>
|
| 24 |
+
<a href="https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-source-code-license/" style="margin: 2px;">
|
| 25 |
+
<img alt="License" src="https://img.shields.io/badge/License-NSCLv1-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
|
| 26 |
+
</a>
|
| 27 |
+
</div>
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
[](./assets/demo.mp4)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
## Model Overview
|
| 34 |
+
|
| 35 |
+
Nemotron-Labs-Diffusion-VLM-8B is the vision-language extension of the Nemotron-Labs-Diffusion family. It pairs the same tri-mode language backbone (AR / diffusion / self-speculation, switchable by attention pattern) with a vision encoder, accepting interleaved image + text input and producing text output. The diffusion-based parallel decoding from the LM family carries over to VLM: the language head can draft a block in parallel and verify autoregressively against shared KV cache, retaining the family's decode-efficiency story while extending it to multimodal prompts.
|
| 36 |
+
|
| 37 |
+
<div align="center">
|
| 38 |
+
<img src="./assets/teaser.png" alt="An illustration of Tri-Mode LMs" width="500">
|
| 39 |
+
</div>
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
## Key Design
|
| 43 |
+
|
| 44 |
+
- 8B vision-language model in the Nemotron-Labs-Diffusion family — same tri-mode language backbone (AR, diffusion, self-speculation) plus a Pixtral-style vision encoder.
|
| 45 |
+
- Vision encoder: 24-layer, 1024-hidden, 14×14 patch, supports up to 1540×1540 images with `spatial_merge_size=2`.
|
| 46 |
+
- Language decoder weights match `nvidia/Nemotron-Labs-Diffusion-8B` (34 layers, 4096 hidden, 14336 intermediate); the model card structure and inference modes inherit from the LM line.
|
| 47 |
+
- Diffusion-based parallel decoding works for multimodal prompts: image tokens are placed in the bidirectional context window and text generation proceeds via the same block-wise unmasking + AR verification as the LM family.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
## License/Terms of Use
|
| 51 |
+
|
| 52 |
+
Use of this model is governed by the **NVIDIA Source Code License (NSCLv1)**.
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
## Environment
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
transformers>=5.0.0
|
| 59 |
+
pillow
|
| 60 |
+
requests
|
| 61 |
+
opencv-python
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
## Chat with Our Model
|
| 66 |
+
|
| 67 |
+
```python
|
| 68 |
+
import sys
|
| 69 |
+
import torch
|
| 70 |
+
from huggingface_hub import snapshot_download
|
| 71 |
+
from transformers import AutoModel, AutoTokenizer
|
| 72 |
+
|
| 73 |
+
repo_name = "nvidia/Nemotron-Labs-Diffusion-VLM-8B"
|
| 74 |
+
sys.path.insert(0, snapshot_download(repo_name))
|
| 75 |
+
from image_processing import process_messages
|
| 76 |
+
|
| 77 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
|
| 78 |
+
model = AutoModel.from_pretrained(repo_name, trust_remote_code=True).cuda().to(torch.bfloat16)
|
| 79 |
+
|
| 80 |
+
image_path = "path/to/your/image.jpg" # local file or http(s):// URL
|
| 81 |
+
messages = [{
|
| 82 |
+
"role": "user",
|
| 83 |
+
"content": [
|
| 84 |
+
{"type": "image_url", "image_url": {"url": image_path}},
|
| 85 |
+
{"type": "text", "text": "Describe this image."},
|
| 86 |
+
],
|
| 87 |
+
}]
|
| 88 |
+
|
| 89 |
+
batch = process_messages(tokenizer, messages, add_generation_prompt=True)
|
| 90 |
+
prompt_ids = batch["input_ids"].to("cuda")
|
| 91 |
+
pixel_values = batch["pixel_values"].to("cuda", dtype=torch.bfloat16)
|
| 92 |
+
|
| 93 |
+
out_ids, nfe = model.generate(
|
| 94 |
+
prompt_ids,
|
| 95 |
+
pixel_values=pixel_values,
|
| 96 |
+
image_sizes=batch["image_sizes"],
|
| 97 |
+
max_new_tokens=512, steps=512, block_length=32,
|
| 98 |
+
shift_logits=False, threshold=0.9,
|
| 99 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
tokenized_out = tokenizer.batch_decode(out_ids[:, prompt_ids.shape[1]:], skip_special_tokens=True)
|
| 103 |
+
print(f"Model: {tokenized_out[0]}")
|
| 104 |
+
print(f"[Num Function Eval (NFE)={nfe}]")
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
## Ethical Considerations
|
| 109 |
+
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the [bias](./model_cards/bias.md), [explainability](./model_cards/explainability.md), [safety & security](./model_cards/safety.md), and [privacy](./model_cards/privacy.md) subcards.
|
| 110 |
+
|
| 111 |
+
Please report model quality, risk, security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
## Citations
|
| 115 |
+
|
| 116 |
+
```bibtex
|
| 117 |
+
@techreport{fu2026nemotronlabsdiffusion,
|
| 118 |
+
title = {Nemotron-Labs-Diffusion: A Tri-Mode Language Model Unifying Autoregressive, Diffusion, and Self-Speculation Decoding},
|
| 119 |
+
author = {Yonggan Fu and Lexington Whalen and Abhinav Garg and Chengyue Wu and Maksim Khadkevich and Nicolai Oswald and Enze Xie and Daniel Egert and Sharath Turuvekere Sreenivas and Shizhe Diao and Chenhan Yu and Ye Yu and Weijia Chen and Sajad Norouzi and Shiyi Lan and Ligeng Zhu and Jin Wang and Jindong Jiang and Morteza Mardani and Mehran Maghoumi and Song Han and Ante Jukic and Nima Tajbakhsh and Jan Kautz and Pavlo Molchanov},
|
| 120 |
+
institution = {NVIDIA},
|
| 121 |
+
year = {2026},
|
| 122 |
+
note = {Technical report}
|
| 123 |
+
}
|
| 124 |
+
```
|
assets/demo.gif
ADDED
|
Git LFS Details
|
assets/demo.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:666d8785ac4af75931d9c677757c4ef9945bf114d07f1c4e2ebb7b893ac39006
|
| 3 |
+
size 9454873
|
assets/result_acc.png
ADDED
|
Git LFS Details
|
assets/result_efficiency.png
ADDED
|
Git LFS Details
|
assets/teaser.png
ADDED
|
Git LFS Details
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% macro render_extra_keys(json_dict, handled_keys) %}
|
| 2 |
+
{%- if json_dict is mapping %}
|
| 3 |
+
{%- for json_key in json_dict if json_key not in handled_keys %}
|
| 4 |
+
{%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
|
| 5 |
+
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
|
| 6 |
+
{%- else %}
|
| 7 |
+
{{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
|
| 8 |
+
{%- endif %}
|
| 9 |
+
{%- endfor %}
|
| 10 |
+
{%- endif %}
|
| 11 |
+
{% endmacro %}
|
| 12 |
+
{%- set enable_thinking = enable_thinking if enable_thinking is defined else True %}
|
| 13 |
+
{%- set truncate_history_thinking = truncate_history_thinking if truncate_history_thinking is defined else True %}
|
| 14 |
+
|
| 15 |
+
{%- set ns = namespace(last_user_idx = -1) %}
|
| 16 |
+
{%- set loop_messages = messages %}
|
| 17 |
+
{%- for m in loop_messages %}
|
| 18 |
+
{%- if m["role"] == "user" %}
|
| 19 |
+
{%- set ns.last_user_idx = loop.index0 %}
|
| 20 |
+
{%- endif %}
|
| 21 |
+
{%- endfor %}
|
| 22 |
+
|
| 23 |
+
{%- if messages[0]["role"] == "system" %}
|
| 24 |
+
{%- set system_message = messages[0]["content"] %}
|
| 25 |
+
{%- set loop_messages = messages[1:] %}
|
| 26 |
+
{%- else %}
|
| 27 |
+
{%- set system_message = "" %}
|
| 28 |
+
{%- set loop_messages = messages %}
|
| 29 |
+
{%- endif %}
|
| 30 |
+
{%- if not tools is defined %}
|
| 31 |
+
{%- set tools = [] %}
|
| 32 |
+
{%- endif %}
|
| 33 |
+
{# Recompute last_user_idx relative to loop_messages after handling system #}
|
| 34 |
+
{%- set ns = namespace(last_user_idx = -1) %}
|
| 35 |
+
{%- for m in loop_messages %}
|
| 36 |
+
{%- if m["role"] == "user" %}
|
| 37 |
+
{%- set ns.last_user_idx = loop.index0 %}
|
| 38 |
+
{%- endif %}
|
| 39 |
+
{%- endfor %}
|
| 40 |
+
{%- if system_message is defined %}
|
| 41 |
+
{{- "<|im_start|>system\n" + system_message }}
|
| 42 |
+
{%- else %}
|
| 43 |
+
{%- if tools is iterable and tools | length > 0 %}
|
| 44 |
+
{{- "<|im_start|>system\n" }}
|
| 45 |
+
{%- endif %}
|
| 46 |
+
{%- endif %}
|
| 47 |
+
{%- if tools is iterable and tools | length > 0 %}
|
| 48 |
+
{%- if system_message is defined and system_message | length > 0 %}
|
| 49 |
+
{{- "\n\n" }}
|
| 50 |
+
{%- endif %}
|
| 51 |
+
{{- "# Tools\n\nYou have access to the following functions:\n\n" }}
|
| 52 |
+
{{- "<tools>" }}
|
| 53 |
+
{%- for tool in tools %}
|
| 54 |
+
{%- if tool.function is defined %}
|
| 55 |
+
{%- set tool = tool.function %}
|
| 56 |
+
{%- endif %}
|
| 57 |
+
{{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
|
| 58 |
+
{%- if tool.description is defined %}
|
| 59 |
+
{{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
|
| 60 |
+
{%- endif %}
|
| 61 |
+
{{- '\n<parameters>' }}
|
| 62 |
+
{%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
|
| 63 |
+
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
| 64 |
+
{{- '\n<parameter>' }}
|
| 65 |
+
{{- '\n<name>' ~ param_name ~ '</name>' }}
|
| 66 |
+
{%- if param_fields.type is defined %}
|
| 67 |
+
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
|
| 68 |
+
{%- endif %}
|
| 69 |
+
{%- if param_fields.description is defined %}
|
| 70 |
+
{{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
|
| 71 |
+
{%- endif %}
|
| 72 |
+
{%- if param_fields.enum is defined %}
|
| 73 |
+
{{- '\n<enum>' ~ (param_fields.enum | tojson | safe) ~ '</enum>' }}
|
| 74 |
+
{%- endif %}
|
| 75 |
+
{%- set handled_keys = ['name', 'type', 'description', 'enum'] %}
|
| 76 |
+
{{- render_extra_keys(param_fields, handled_keys) }}
|
| 77 |
+
{{- '\n</parameter>' }}
|
| 78 |
+
{%- endfor %}
|
| 79 |
+
{%- endif %}
|
| 80 |
+
{% set handled_keys = ['type', 'properties', 'required'] %}
|
| 81 |
+
{{- render_extra_keys(tool.parameters, handled_keys) }}
|
| 82 |
+
{%- if tool.parameters is defined and tool.parameters.required is defined %}
|
| 83 |
+
{{- '\n<required>' ~ (tool.parameters.required | tojson | safe) ~ '</required>' }}
|
| 84 |
+
{%- endif %}
|
| 85 |
+
{{- '\n</parameters>' }}
|
| 86 |
+
{%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
|
| 87 |
+
{{- render_extra_keys(tool, handled_keys) }}
|
| 88 |
+
{{- '\n</function>' }}
|
| 89 |
+
{%- endfor %}
|
| 90 |
+
{{- "\n</tools>" }}
|
| 91 |
+
|
| 92 |
+
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
| 93 |
+
{%- endif %}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
{%- if system_message is defined %}
|
| 97 |
+
{{- '<|im_end|>\n' }}
|
| 98 |
+
{%- else %}
|
| 99 |
+
{%- if tools is iterable and tools | length > 0 %}
|
| 100 |
+
{{- '<|im_end|>\n' }}
|
| 101 |
+
{%- endif %}
|
| 102 |
+
{%- endif %}
|
| 103 |
+
|
| 104 |
+
{%- for message in loop_messages %}
|
| 105 |
+
{%- if message.role == "assistant" %}
|
| 106 |
+
{# Add reasoning content in to content field for unified processing below. #}
|
| 107 |
+
{%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
|
| 108 |
+
{%- set msg_content = message.content | default('', true) %}
|
| 109 |
+
{%- if msg_content is not string -%}
|
| 110 |
+
{%- set ns_rc = namespace(text='') -%}
|
| 111 |
+
{%- for block in msg_content -%}
|
| 112 |
+
{%- if block.type is defined and block.type == "text" -%}
|
| 113 |
+
{%- set ns_rc.text = ns_rc.text + block.text -%}
|
| 114 |
+
{%- endif -%}
|
| 115 |
+
{%- endfor -%}
|
| 116 |
+
{%- set msg_content = ns_rc.text -%}
|
| 117 |
+
{%- endif -%}
|
| 118 |
+
{%- set content = "<think>\n" ~ message.reasoning_content ~ "\n</think>\n" ~ msg_content %}
|
| 119 |
+
{%- else %}
|
| 120 |
+
{%- set content = message.content | default('', true) %}
|
| 121 |
+
{%- if content is not string -%}
|
| 122 |
+
{%- set ns_c = namespace(text='') -%}
|
| 123 |
+
{%- for block in content -%}
|
| 124 |
+
{%- if block.type is defined and block.type == "text" -%}
|
| 125 |
+
{%- set ns_c.text = ns_c.text + block.text -%}
|
| 126 |
+
{%- endif -%}
|
| 127 |
+
{%- endfor -%}
|
| 128 |
+
{%- set content = ns_c.text -%}
|
| 129 |
+
{%- endif -%}
|
| 130 |
+
{%- if '<think>' not in content and '</think>' not in content -%}
|
| 131 |
+
{%- set content = "<think></think>" ~ content -%}
|
| 132 |
+
{%- endif -%}
|
| 133 |
+
{%- endif %}
|
| 134 |
+
{%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
|
| 135 |
+
{# Assistant message has tool calls. #}
|
| 136 |
+
{{- '<|im_start|>assistant\n' }}
|
| 137 |
+
{%- set include_content = not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
|
| 138 |
+
{%- if content is string and content | trim | length > 0 %}
|
| 139 |
+
{%- if include_content %}
|
| 140 |
+
{{- (content | trim) ~ '\n' -}}
|
| 141 |
+
{%- else %}
|
| 142 |
+
{%- set c = (content | string) %}
|
| 143 |
+
{%- if '</think>' in c %}
|
| 144 |
+
{# Keep only content after the last closing think. Also generation prompt causes this. #}
|
| 145 |
+
{%- set c = c.split('</think>')[-1] %}
|
| 146 |
+
{%- elif '<think>' in c %}
|
| 147 |
+
{# If <think> was opened but never closed, drop the trailing think segment #}
|
| 148 |
+
{%- set c = c.split('<think>')[0] %}
|
| 149 |
+
{%- endif %}
|
| 150 |
+
{%- set c = "<think></think>" ~ c | trim %}
|
| 151 |
+
{%- if c | length > 0 %}
|
| 152 |
+
{{- c ~ '\n' -}}
|
| 153 |
+
{%- endif %}
|
| 154 |
+
{%- endif %}
|
| 155 |
+
{%- else %}
|
| 156 |
+
{{- "<think></think>" -}}
|
| 157 |
+
{%- endif %}
|
| 158 |
+
{%- for tool_call in message.tool_calls %}
|
| 159 |
+
{%- if tool_call.function is defined %}
|
| 160 |
+
{%- set tool_call = tool_call.function %}
|
| 161 |
+
{%- endif %}
|
| 162 |
+
{{- '<tool_call>\n<function=' ~ tool_call.name ~ '>\n' -}}
|
| 163 |
+
{%- if tool_call.arguments is defined %}
|
| 164 |
+
{%- for args_name, args_value in tool_call.arguments|items %}
|
| 165 |
+
{{- '<parameter=' ~ args_name ~ '>\n' -}}
|
| 166 |
+
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
| 167 |
+
{{- args_value ~ '\n</parameter>\n' -}}
|
| 168 |
+
{%- endfor %}
|
| 169 |
+
{%- endif %}
|
| 170 |
+
{{- '</function>\n</tool_call>\n' -}}
|
| 171 |
+
{%- endfor %}
|
| 172 |
+
{{- '<|im_end|>\n' }}
|
| 173 |
+
{%- else %}
|
| 174 |
+
{# Assistant message doesn't have tool calls. #}
|
| 175 |
+
{%- if not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
|
| 176 |
+
{{- '<|im_start|>assistant\n' ~ (content | default('', true) | string | trim) ~ '<|im_end|>\n' }}
|
| 177 |
+
{%- else %}
|
| 178 |
+
{%- set c = (content | default('', true) | string) %}
|
| 179 |
+
{%- if '<think>' in c and '</think>' in c %}
|
| 180 |
+
{%- set c = "<think></think>" ~ c.split('</think>')[-1] %}
|
| 181 |
+
{%- endif %}
|
| 182 |
+
{%- set c = c | trim %}
|
| 183 |
+
{%- if c | length > 0 %}
|
| 184 |
+
{{- '<|im_start|>assistant\n' ~ c ~ '<|im_end|>\n' }}
|
| 185 |
+
{%- else %}
|
| 186 |
+
{{- '<|im_start|>assistant\n<|im_end|>\n' }}
|
| 187 |
+
{%- endif %}
|
| 188 |
+
{%- endif %}
|
| 189 |
+
{%- endif %}
|
| 190 |
+
{%- elif message.role == "user" or message.role == "system" %}
|
| 191 |
+
{{- '<|im_start|>' + message.role + '\n' }}
|
| 192 |
+
{%- if message.content is string %}
|
| 193 |
+
{{- message.content }}
|
| 194 |
+
{%- else %}
|
| 195 |
+
{%- for block in message.content %}
|
| 196 |
+
{%- if block.type == "text" %}
|
| 197 |
+
{{- block.text }}
|
| 198 |
+
{%- elif block.type in ["image", "image_url"] %}
|
| 199 |
+
{{- '<|image_start|>' }}
|
| 200 |
+
{%- endif %}
|
| 201 |
+
{%- endfor %}
|
| 202 |
+
{%- endif %}
|
| 203 |
+
{{- '<|im_end|>\n' }}
|
| 204 |
+
{%- elif message.role == "tool" %}
|
| 205 |
+
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
| 206 |
+
{{- '<|im_start|>user\n' }}
|
| 207 |
+
{%- endif %}
|
| 208 |
+
{{- '<tool_response>\n' }}
|
| 209 |
+
{{- message.content }}
|
| 210 |
+
{{- '\n</tool_response>\n' }}
|
| 211 |
+
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
| 212 |
+
{{- '<|im_end|>\n' }}
|
| 213 |
+
{%- elif loop.last %}
|
| 214 |
+
{{- '<|im_end|>\n' }}
|
| 215 |
+
{%- endif %}
|
| 216 |
+
{%- else %}
|
| 217 |
+
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
|
| 218 |
+
{%- endif %}
|
| 219 |
+
{%- endfor %}
|
| 220 |
+
|
| 221 |
+
{%- if add_generation_prompt %}
|
| 222 |
+
{%- if enable_thinking %}
|
| 223 |
+
{{- '<|im_start|>assistant\n<think>\n' }}
|
| 224 |
+
{%- else %}
|
| 225 |
+
{{- '<|im_start|>assistant\n<think></think>' }}
|
| 226 |
+
{%- endif %}
|
| 227 |
+
{%- endif %}
|
chat_utils.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def add_gumbel_noise(logits, temperature):
|
| 7 |
+
'''
|
| 8 |
+
The Gumbel max is a method for sampling categorical distributions.
|
| 9 |
+
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
|
| 10 |
+
Thus, we use float64.
|
| 11 |
+
'''
|
| 12 |
+
if temperature == 0:
|
| 13 |
+
return logits
|
| 14 |
+
logits = logits.to(torch.float64)
|
| 15 |
+
noise = torch.rand_like(logits, dtype=torch.float64)
|
| 16 |
+
gumbel_noise = (- torch.log(noise)) ** temperature
|
| 17 |
+
return logits.exp() / gumbel_noise
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
|
| 21 |
+
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
| 22 |
+
x0 = torch.argmax(logits_with_noise, dim=-1)
|
| 23 |
+
|
| 24 |
+
if remasking == 'low_confidence':
|
| 25 |
+
# p = F.softmax(logits.to(torch.float64), dim=-1)
|
| 26 |
+
p = F.softmax(logits, dim=-1)
|
| 27 |
+
x0_p = torch.squeeze(
|
| 28 |
+
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
|
| 29 |
+
elif remasking == 'top_p_margin':
|
| 30 |
+
# Compute probabilities
|
| 31 |
+
p = F.softmax(logits, dim=-1) # (B, L, V)
|
| 32 |
+
# Top-2 per position
|
| 33 |
+
top2 = torch.topk(p, k=2, dim=-1).values # (B, L, 2)
|
| 34 |
+
margin = top2[..., 0] - top2[..., 1] # (B, L)
|
| 35 |
+
|
| 36 |
+
# Normalize margin to [0,1] over MASKED positions per row
|
| 37 |
+
plus_inf = torch.full_like(margin, float('inf'))
|
| 38 |
+
minus_inf = torch.full_like(margin, float('-inf'))
|
| 39 |
+
masked_for_min = torch.where(mask_index, margin, plus_inf)
|
| 40 |
+
masked_for_max = torch.where(mask_index, margin, minus_inf)
|
| 41 |
+
row_min = masked_for_min.amin(dim=1, keepdim=True) # (B, 1)
|
| 42 |
+
row_max = masked_for_max.amax(dim=1, keepdim=True) # (B, 1)
|
| 43 |
+
denom = (row_max - row_min)
|
| 44 |
+
|
| 45 |
+
# If denom==0 (all equal), set normalized=1 on masked; 0 elsewhere by default
|
| 46 |
+
normalized = torch.zeros_like(margin)
|
| 47 |
+
nonzero = denom > 0
|
| 48 |
+
normalized = torch.where(
|
| 49 |
+
mask_index & nonzero,
|
| 50 |
+
(margin - row_min) / (denom + 1e-12),
|
| 51 |
+
normalized
|
| 52 |
+
)
|
| 53 |
+
normalized = torch.where(
|
| 54 |
+
mask_index & (~nonzero),
|
| 55 |
+
torch.ones_like(normalized),
|
| 56 |
+
normalized
|
| 57 |
+
)
|
| 58 |
+
x0_p = normalized # ∈ [0,1] on masked positions
|
| 59 |
+
elif remasking == 'random':
|
| 60 |
+
x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
|
| 61 |
+
else:
|
| 62 |
+
raise NotImplementedError(remasking)
|
| 63 |
+
|
| 64 |
+
# Calculate negative entropy if requested
|
| 65 |
+
if neg_entropy:
|
| 66 |
+
# p = F.softmax(logits.to(torch.float64), dim=-1)
|
| 67 |
+
p = F.softmax(logits, dim=-1)
|
| 68 |
+
epsilon = 1e-10
|
| 69 |
+
log_probs = torch.log(p + epsilon)
|
| 70 |
+
confidence_scores = torch.sum(p * log_probs, dim=-1) # negative entropy per position
|
| 71 |
+
else:
|
| 72 |
+
confidence_scores = x0_p
|
| 73 |
+
|
| 74 |
+
x0 = torch.where(mask_index, x0, x)
|
| 75 |
+
confidence = torch.where(mask_index, confidence_scores, -np.inf)
|
| 76 |
+
|
| 77 |
+
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
| 78 |
+
if threshold is not None:
|
| 79 |
+
num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
|
| 80 |
+
# print(f'confidence: {confidence}')
|
| 81 |
+
for j in range(confidence.shape[0]):
|
| 82 |
+
_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
|
| 83 |
+
transfer_index[j, select_index] = True
|
| 84 |
+
if threshold is not None:
|
| 85 |
+
for k in range(1, num_transfer_tokens[j]):
|
| 86 |
+
if confidence[j, select_index[k]] < threshold:
|
| 87 |
+
transfer_index[j, select_index[k]] = False
|
| 88 |
+
return x0, transfer_index
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_num_transfer_tokens(mask_index, steps: int):
|
| 92 |
+
mask_num = mask_index.sum(dim=1, keepdim=True)
|
| 93 |
+
base = mask_num // steps
|
| 94 |
+
remainder = mask_num % steps
|
| 95 |
+
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
|
| 96 |
+
for i in range(mask_num.size(0)):
|
| 97 |
+
num_transfer_tokens[i, : int(remainder[i])] += 1
|
| 98 |
+
return num_transfer_tokens
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
def generate_with_prefix_cache_block_diff(
|
| 103 |
+
model,
|
| 104 |
+
prompt,
|
| 105 |
+
steps=128,
|
| 106 |
+
gen_length=128,
|
| 107 |
+
block_length=128,
|
| 108 |
+
temperature=0.,
|
| 109 |
+
remasking='low_confidence',
|
| 110 |
+
mask_id=126336,
|
| 111 |
+
threshold=None,
|
| 112 |
+
factor=None,
|
| 113 |
+
shift_logits=False,
|
| 114 |
+
neg_entropy=False,
|
| 115 |
+
causal_context=False,
|
| 116 |
+
pixel_values=None,
|
| 117 |
+
image_sizes=None,
|
| 118 |
+
eos_token_id=None,
|
| 119 |
+
):
|
| 120 |
+
dream_style=shift_logits
|
| 121 |
+
x_accum = prompt.clone()
|
| 122 |
+
|
| 123 |
+
assert gen_length % block_length == 0
|
| 124 |
+
num_blocks = gen_length // block_length
|
| 125 |
+
|
| 126 |
+
assert steps % num_blocks == 0
|
| 127 |
+
steps_per_block = steps // num_blocks
|
| 128 |
+
|
| 129 |
+
nfe = 0
|
| 130 |
+
|
| 131 |
+
if causal_context:
|
| 132 |
+
model_module = model.module if hasattr(model, "module") else model
|
| 133 |
+
for layer in model_module.encoder.layers:
|
| 134 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 135 |
+
layer.self_attn.diffusion_lm=False
|
| 136 |
+
|
| 137 |
+
# Compute KV cache for the prompt initially
|
| 138 |
+
# Pass pixel_values/image_sizes only for this first call (prompt contains image tokens)
|
| 139 |
+
output = model(prompt, use_cache=True, use_causal_mask=causal_context,
|
| 140 |
+
pixel_values=pixel_values, image_sizes=image_sizes)
|
| 141 |
+
past_key_values = output.past_key_values
|
| 142 |
+
|
| 143 |
+
if causal_context:
|
| 144 |
+
for layer in model_module.encoder.layers:
|
| 145 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 146 |
+
layer.self_attn.diffusion_lm=True
|
| 147 |
+
|
| 148 |
+
# For dream_style: store the "next token logit" of the context
|
| 149 |
+
next_logits_context = None
|
| 150 |
+
if dream_style:
|
| 151 |
+
next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
|
| 152 |
+
|
| 153 |
+
for num_block in range(num_blocks):
|
| 154 |
+
# Create a new block with mask tokens (no seeding)
|
| 155 |
+
mask_block = torch.ones(
|
| 156 |
+
(prompt.shape[0], block_length),
|
| 157 |
+
dtype=prompt.dtype,
|
| 158 |
+
device=prompt.device
|
| 159 |
+
) * mask_id
|
| 160 |
+
|
| 161 |
+
# Append the block of masks
|
| 162 |
+
x_accum = torch.cat([x_accum, mask_block], dim=1)
|
| 163 |
+
current_block_start = prompt.size(1) + num_block * block_length
|
| 164 |
+
block_slice = slice(current_block_start, current_block_start + block_length)
|
| 165 |
+
|
| 166 |
+
# Build the initial mask for this block
|
| 167 |
+
mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
|
| 168 |
+
|
| 169 |
+
# Precompute the transfer schedule for this block
|
| 170 |
+
if dream_style:
|
| 171 |
+
# still denoise *all* positions (0..Lb-1), since none are seeded
|
| 172 |
+
schedule_mask = mask_block_idx0
|
| 173 |
+
else:
|
| 174 |
+
schedule_mask = mask_block_idx0
|
| 175 |
+
|
| 176 |
+
num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
|
| 177 |
+
|
| 178 |
+
# Denoise the current block
|
| 179 |
+
for i in range(steps_per_block):
|
| 180 |
+
mask_block_idx = (x_accum[:, block_slice] == mask_id) # (B, Lb)
|
| 181 |
+
if mask_block_idx.sum() == 0:
|
| 182 |
+
break
|
| 183 |
+
|
| 184 |
+
nfe += 1
|
| 185 |
+
|
| 186 |
+
# Forward only the current noisy block using cached context
|
| 187 |
+
logits_block = model(
|
| 188 |
+
x_accum[:, block_slice],
|
| 189 |
+
past_key_values=past_key_values,
|
| 190 |
+
use_cache=False
|
| 191 |
+
).logits
|
| 192 |
+
|
| 193 |
+
if dream_style:
|
| 194 |
+
# Align logits so that each masked position has a predictor:
|
| 195 |
+
# prepend context-next logit, then use logits_block[:-1]
|
| 196 |
+
if block_length == 1:
|
| 197 |
+
logits_use = next_logits_context # (B, 1, V)
|
| 198 |
+
else:
|
| 199 |
+
logits_use = torch.cat(
|
| 200 |
+
[next_logits_context, logits_block[:, :-1, :]],
|
| 201 |
+
dim=1
|
| 202 |
+
) # (B, Lb, V)
|
| 203 |
+
|
| 204 |
+
mask_use = mask_block_idx # (B, Lb)
|
| 205 |
+
x_use = x_accum[:, block_slice] # (B, Lb)
|
| 206 |
+
|
| 207 |
+
x0, transfer_idx = get_transfer_index(
|
| 208 |
+
logits_use, temperature, remasking, mask_use, x_use,
|
| 209 |
+
num_transfer_tokens=num_transfer_tokens[:, i],
|
| 210 |
+
threshold=threshold, neg_entropy=neg_entropy
|
| 211 |
+
)
|
| 212 |
+
cur = x_accum[:, block_slice].clone()
|
| 213 |
+
cur[transfer_idx] = x0[transfer_idx]
|
| 214 |
+
x_accum[:, block_slice] = cur
|
| 215 |
+
|
| 216 |
+
else:
|
| 217 |
+
# non-AR (same-position) case
|
| 218 |
+
x0, transfer_idx = get_transfer_index(
|
| 219 |
+
logits_block, temperature, remasking, mask_block_idx,
|
| 220 |
+
x_accum[:, block_slice],
|
| 221 |
+
num_transfer_tokens=num_transfer_tokens[:, i],
|
| 222 |
+
threshold=threshold, neg_entropy=neg_entropy
|
| 223 |
+
)
|
| 224 |
+
cur = x_accum[:, block_slice].clone()
|
| 225 |
+
cur[transfer_idx] = x0[transfer_idx]
|
| 226 |
+
x_accum[:, block_slice] = cur
|
| 227 |
+
|
| 228 |
+
if eos_token_id is not None:
|
| 229 |
+
block_tokens = x_accum[:, block_slice] # (B, Lb)
|
| 230 |
+
eos_mask = (block_tokens == eos_token_id) # (B, Lb)
|
| 231 |
+
any_eos = eos_mask.any(dim=1) # (B,)
|
| 232 |
+
if any_eos.any():
|
| 233 |
+
after_eos = eos_mask.cumsum(dim=1).bool() # (B, Lb)
|
| 234 |
+
mask_before = (block_tokens == mask_id) & ~after_eos
|
| 235 |
+
if (any_eos & ~mask_before.any(dim=1)).any():
|
| 236 |
+
break
|
| 237 |
+
|
| 238 |
+
if causal_context:
|
| 239 |
+
for layer in model_module.encoder.layers:
|
| 240 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 241 |
+
layer.self_attn.diffusion_lm=False
|
| 242 |
+
|
| 243 |
+
# after block is fully denoised, update KV cache
|
| 244 |
+
output = model(
|
| 245 |
+
x_accum[:, block_slice],
|
| 246 |
+
past_key_values=past_key_values,
|
| 247 |
+
use_cache=True,
|
| 248 |
+
use_causal_mask=causal_context
|
| 249 |
+
)
|
| 250 |
+
past_key_values = output.past_key_values
|
| 251 |
+
|
| 252 |
+
if causal_context:
|
| 253 |
+
for layer in model_module.encoder.layers:
|
| 254 |
+
if hasattr(layer.self_attn, 'diffusion_lm'):
|
| 255 |
+
layer.self_attn.diffusion_lm=True
|
| 256 |
+
|
| 257 |
+
if dream_style and num_block < num_blocks - 1:
|
| 258 |
+
# refresh context-next logit for the next block
|
| 259 |
+
next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
|
| 260 |
+
|
| 261 |
+
if eos_token_id is not None:
|
| 262 |
+
gen_so_far = x_accum[:, prompt.size(1):] # (B, gen_len_so_far)
|
| 263 |
+
is_eos = (gen_so_far == eos_token_id) # (B, gen_len_so_far)
|
| 264 |
+
has_eos = is_eos.any(dim=1) # (B,)
|
| 265 |
+
if has_eos.all():
|
| 266 |
+
return x_accum, nfe
|
| 267 |
+
|
| 268 |
+
# first_eos_pos = is_eos.to(torch.int64).argmax(dim=1) # (B,)
|
| 269 |
+
# max_eos = first_eos_pos.max().item()
|
| 270 |
+
# return x_accum[:, : prompt.size(1) + max_eos + 1], nfe
|
| 271 |
+
|
| 272 |
+
return x_accum, nfe
|
config.json
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"ada_dlm_loss_ratio": null,
|
| 3 |
+
"ada_perm_ratio_global": null,
|
| 4 |
+
"ada_perm_ratio_per_block": null,
|
| 5 |
+
"adaptive_mask_rate": false,
|
| 6 |
+
"always_mask_im_end": false,
|
| 7 |
+
"ar_loss_weight": 1.0,
|
| 8 |
+
"architectures": [
|
| 9 |
+
"NemotronLabsDiffusionVLMModel"
|
| 10 |
+
],
|
| 11 |
+
"attention_bias": false,
|
| 12 |
+
"attention_dropout": 0.0,
|
| 13 |
+
"attn_implementation": null,
|
| 14 |
+
"auto_map": {
|
| 15 |
+
"AutoConfig": "configuration_nemotron_labs_diffusion_vlm.NemotronLabsDiffusionVLMConfig",
|
| 16 |
+
"AutoModel": "modeling_nemotron_labs_diffusion_vlm.NemotronLabsDiffusionVLMModel",
|
| 17 |
+
"AutoModelForCausalLM": "modeling_nemotron_labs_diffusion_vlm.NemotronLabsDiffusionVLMModel"
|
| 18 |
+
},
|
| 19 |
+
"block_size": 32,
|
| 20 |
+
"bos_token_id": 1,
|
| 21 |
+
"complementary_mask": true,
|
| 22 |
+
"diff_loss_weight": 0.5,
|
| 23 |
+
"dlm_arch": "encoder",
|
| 24 |
+
"dlm_loss_weight": 0.5,
|
| 25 |
+
"dlm_paradigm": "bidirectional",
|
| 26 |
+
"dlm_type": "llada",
|
| 27 |
+
"dp_varying_mask_ratio": false,
|
| 28 |
+
"dtype": "bfloat16",
|
| 29 |
+
"enforce_mask": false,
|
| 30 |
+
"eos_token_id": 11,
|
| 31 |
+
"global_loss_avg": true,
|
| 32 |
+
"head_dim": 128,
|
| 33 |
+
"hidden_act": "silu",
|
| 34 |
+
"hidden_size": 4096,
|
| 35 |
+
"im_end_token_id": 11,
|
| 36 |
+
"initializer_range": 0.02,
|
| 37 |
+
"intermediate_size": 14336,
|
| 38 |
+
"mask_token_id": 100,
|
| 39 |
+
"max_position_embeddings": 262144,
|
| 40 |
+
"mlp_bias": false,
|
| 41 |
+
"model_type": "nemotron_labs_diffusion_vlm",
|
| 42 |
+
"multi_sampling": null,
|
| 43 |
+
"multimodal_projector_bias": false,
|
| 44 |
+
"num_ar_layers": 0,
|
| 45 |
+
"num_attention_heads": 32,
|
| 46 |
+
"num_diffusion_layers": 0,
|
| 47 |
+
"num_hidden_layers": 34,
|
| 48 |
+
"num_key_value_heads": 8,
|
| 49 |
+
"num_skip_loss_tokens": 0,
|
| 50 |
+
"pad_token_id": 11,
|
| 51 |
+
"prefix_ratio": 0.8,
|
| 52 |
+
"projector_hidden_act": "gelu",
|
| 53 |
+
"random_length_prob": 0,
|
| 54 |
+
"rms_norm_eps": 1e-05,
|
| 55 |
+
"rope_parameters": {
|
| 56 |
+
"beta_fast": 32.0,
|
| 57 |
+
"beta_slow": 1.0,
|
| 58 |
+
"factor": 16.0,
|
| 59 |
+
"llama_4_scaling_beta": 0.1,
|
| 60 |
+
"mscale": 1.0,
|
| 61 |
+
"mscale_all_dim": 1.0,
|
| 62 |
+
"original_max_position_embeddings": 16384,
|
| 63 |
+
"rope_theta": 1000000.0,
|
| 64 |
+
"rope_type": "yarn",
|
| 65 |
+
"type": "yarn"
|
| 66 |
+
},
|
| 67 |
+
"rope_scaling": {
|
| 68 |
+
"beta_fast": 32.0,
|
| 69 |
+
"beta_slow": 1.0,
|
| 70 |
+
"factor": 16.0,
|
| 71 |
+
"llama_4_scaling_beta": 0.1,
|
| 72 |
+
"mscale": 1.0,
|
| 73 |
+
"mscale_all_dim": 1.0,
|
| 74 |
+
"original_max_position_embeddings": 16384,
|
| 75 |
+
"rope_theta": 1000000.0,
|
| 76 |
+
"rope_type": "yarn",
|
| 77 |
+
"type": "yarn"
|
| 78 |
+
},
|
| 79 |
+
"rope_theta": 1000000.0,
|
| 80 |
+
"sliding_window": null,
|
| 81 |
+
"spatial_merge_size": 2,
|
| 82 |
+
"tie_word_embeddings": false,
|
| 83 |
+
"tok_mask_half_life_ratio": null,
|
| 84 |
+
"transformers_version": "4.57.1",
|
| 85 |
+
"use_cache": false,
|
| 86 |
+
"vision_config": {
|
| 87 |
+
"attention_dropout": 0.0,
|
| 88 |
+
"head_dim": 64,
|
| 89 |
+
"hidden_act": "silu",
|
| 90 |
+
"hidden_size": 1024,
|
| 91 |
+
"image_size": 1540,
|
| 92 |
+
"initializer_range": 0.02,
|
| 93 |
+
"intermediate_size": 4096,
|
| 94 |
+
"model_type": "pixtral",
|
| 95 |
+
"num_attention_heads": 16,
|
| 96 |
+
"num_channels": 3,
|
| 97 |
+
"num_hidden_layers": 24,
|
| 98 |
+
"patch_size": 14,
|
| 99 |
+
"rope_parameters": {
|
| 100 |
+
"rope_theta": 10000.0,
|
| 101 |
+
"rope_type": "default"
|
| 102 |
+
}
|
| 103 |
+
},
|
| 104 |
+
"vision_feature_layer": -1,
|
| 105 |
+
"vocab_size": 131073
|
| 106 |
+
}
|
configuration_nemotron_labs_diffusion_vlm.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Nemotron-Labs Diffusion VLM model configuration"""
|
| 16 |
+
|
| 17 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 19 |
+
from transformers.utils import logging
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class NemotronLabsDiffusionVLMConfig(PretrainedConfig):
|
| 26 |
+
r"""
|
| 27 |
+
This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
|
| 28 |
+
It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
|
| 29 |
+
|
| 30 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 31 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
vocab_size (`int`, *optional*, defaults to 131072):
|
| 35 |
+
Vocabulary size of the Ministral model.
|
| 36 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 37 |
+
Dimension of the hidden representations.
|
| 38 |
+
intermediate_size (`int`, *optional*, defaults to 14336):
|
| 39 |
+
Dimension of the MLP representations.
|
| 40 |
+
num_hidden_layers (`int`, *optional*, defaults to 34):
|
| 41 |
+
Number of hidden layers in the Transformer decoder.
|
| 42 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 43 |
+
Number of attention heads for each attention layer.
|
| 44 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 45 |
+
Number of key_value heads for Grouped Query Attention.
|
| 46 |
+
head_dim (`int`, *optional*, defaults to 128):
|
| 47 |
+
The attention head dimension.
|
| 48 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 49 |
+
The non-linear activation function.
|
| 50 |
+
max_position_embeddings (`int`, *optional*, defaults to 262144):
|
| 51 |
+
The maximum sequence length.
|
| 52 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 53 |
+
The standard deviation of the truncated_normal_initializer.
|
| 54 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 55 |
+
The epsilon used by the rms normalization layers.
|
| 56 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 57 |
+
Whether or not the model should return the last key/values attentions.
|
| 58 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 59 |
+
Whether the model's input and output word embeddings should be tied.
|
| 60 |
+
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
| 61 |
+
The base period of the RoPE embeddings.
|
| 62 |
+
rope_parameters (`Dict`, *optional*):
|
| 63 |
+
Dictionary containing the scaling configuration for the RoPE embeddings.
|
| 64 |
+
Default uses YaRN scaling with factor=16, original_max_position_embeddings=16384.
|
| 65 |
+
attention_bias (`bool`, defaults to `False`):
|
| 66 |
+
Whether to use a bias in the query, key, value and output projection layers.
|
| 67 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 68 |
+
The dropout ratio for the attention probabilities.
|
| 69 |
+
mlp_bias (`bool`, *optional*, defaults to `False`):
|
| 70 |
+
Whether to use a bias in up_proj, down_proj and gate_proj layers.
|
| 71 |
+
sliding_window (`int`, *optional*, defaults to None):
|
| 72 |
+
Sliding window attention size.
|
| 73 |
+
mask_token_id (`int`, *optional*, defaults to -1):
|
| 74 |
+
Token ID for masking in diffusion.
|
| 75 |
+
dlm_type (`str`, *optional*, defaults to 'llada'):
|
| 76 |
+
Type of diffusion language model ('llada', 'dream').
|
| 77 |
+
random_length_prob (`float`, *optional*):
|
| 78 |
+
Probability of using random lengths during training.
|
| 79 |
+
num_ar_layers (`int`, *optional*, defaults to 0):
|
| 80 |
+
Number of autoregressive layers.
|
| 81 |
+
num_diffusion_layers (`int`, *optional*, defaults to 0):
|
| 82 |
+
Number of diffusion layers.
|
| 83 |
+
diff_loss_weight (`float`, *optional*, defaults to 1):
|
| 84 |
+
Weight for diffusion loss.
|
| 85 |
+
enforce_mask (`bool`, *optional*, defaults to False):
|
| 86 |
+
Whether to enforce masking.
|
| 87 |
+
prefix_ratio (`float`, *optional*, defaults to 0.8):
|
| 88 |
+
Ratio for prefix in prefix_bidirectional mode.
|
| 89 |
+
dlm_paradigm (`str`, *optional*, defaults to 'bidirectional'):
|
| 90 |
+
Paradigm for diffusion ('bidirectional', 'autoregressive', 'prefix_bidirectional', 'efficient_block_diff', 'block_diff', 'sbd_block_diff').
|
| 91 |
+
dlm_arch (`str`, *optional*, defaults to 'encoder'):
|
| 92 |
+
Architecture type ('encoder', 'encoder_decoder').
|
| 93 |
+
block_size (`int`, *optional*, defaults to 32):
|
| 94 |
+
Block size for block diffusion paradigms.
|
| 95 |
+
tok_mask_half_life_ratio (`float`, *optional*):
|
| 96 |
+
Half-life ratio for token masking.
|
| 97 |
+
adaptive_mask_rate (`bool`, *optional*, defaults to False):
|
| 98 |
+
Whether to use adaptive mask rate.
|
| 99 |
+
multi_sampling (`int`, *optional*):
|
| 100 |
+
Number of samples for multi-sampling.
|
| 101 |
+
num_skip_loss_tokens (`int`, *optional*, defaults to 0):
|
| 102 |
+
Number of tokens to skip in loss calculation.
|
| 103 |
+
dlm_loss_weight (`float`, *optional*):
|
| 104 |
+
Weight for diffusion LM loss.
|
| 105 |
+
ar_loss_weight (`float`, *optional*, defaults to 1.0):
|
| 106 |
+
Weight for autoregressive loss in sbd_block_diff paradigm. Use 10000 to only use AR loss.
|
| 107 |
+
global_loss_avg (`bool`, *optional*, defaults to False):
|
| 108 |
+
Whether to use global loss average.
|
| 109 |
+
dp_varying_mask_ratio (`bool`, *optional*, defaults to False):
|
| 110 |
+
Whether to use varying mask ratio for each DP rank during sampling.
|
| 111 |
+
ada_perm_ratio_per_block (`float`, *optional*):
|
| 112 |
+
Adaptive permutation ratio for each block.
|
| 113 |
+
ada_perm_ratio_global (`float`, *optional*):
|
| 114 |
+
Adaptive permutation ratio for global.
|
| 115 |
+
complementary_mask (`bool`, *optional*, defaults to False):
|
| 116 |
+
Whether to use complementary masking (mask + inverse mask).
|
| 117 |
+
always_mask_im_end (`bool`, *optional*, defaults to False):
|
| 118 |
+
Whether to always mask im_end tokens.
|
| 119 |
+
im_end_token_id (`int`, *optional*, defaults to 11):
|
| 120 |
+
Token ID for im_end in always_mask_im_end.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
model_type = "nemotron_labs_diffusion_vlm"
|
| 124 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 125 |
+
|
| 126 |
+
# Default tensor parallel plan for base model `Ministral`
|
| 127 |
+
base_model_tp_plan = {
|
| 128 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 129 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 130 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 131 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 132 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 133 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 134 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 135 |
+
}
|
| 136 |
+
base_model_pp_plan = {
|
| 137 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 138 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 139 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
def __init__(
|
| 143 |
+
self,
|
| 144 |
+
vocab_size=131072,
|
| 145 |
+
hidden_size=4096,
|
| 146 |
+
intermediate_size=14336,
|
| 147 |
+
num_hidden_layers=34,
|
| 148 |
+
num_attention_heads=32,
|
| 149 |
+
num_key_value_heads=8,
|
| 150 |
+
head_dim=128,
|
| 151 |
+
hidden_act="silu",
|
| 152 |
+
max_position_embeddings=262144,
|
| 153 |
+
initializer_range=0.02,
|
| 154 |
+
rms_norm_eps=1e-05,
|
| 155 |
+
use_cache=True,
|
| 156 |
+
pad_token_id=None,
|
| 157 |
+
bos_token_id=1,
|
| 158 |
+
eos_token_id=2,
|
| 159 |
+
tie_word_embeddings=False,
|
| 160 |
+
rope_theta=1000000.0,
|
| 161 |
+
rope_parameters=None,
|
| 162 |
+
rope_scaling=None,
|
| 163 |
+
attention_bias=False,
|
| 164 |
+
attention_dropout=0.0,
|
| 165 |
+
mlp_bias=False,
|
| 166 |
+
sliding_window=None,
|
| 167 |
+
attn_implementation="sdpa",
|
| 168 |
+
mask_token_id=-1,
|
| 169 |
+
dlm_type='llada',
|
| 170 |
+
random_length_prob=None,
|
| 171 |
+
num_ar_layers=0,
|
| 172 |
+
num_diffusion_layers=0,
|
| 173 |
+
diff_loss_weight=1,
|
| 174 |
+
enforce_mask=False,
|
| 175 |
+
prefix_ratio=0.8,
|
| 176 |
+
dlm_paradigm='bidirectional',
|
| 177 |
+
dlm_arch='encoder',
|
| 178 |
+
block_size=32,
|
| 179 |
+
tok_mask_half_life_ratio=None,
|
| 180 |
+
adaptive_mask_rate=False,
|
| 181 |
+
multi_sampling=None,
|
| 182 |
+
num_skip_loss_tokens=0,
|
| 183 |
+
dlm_loss_weight=None,
|
| 184 |
+
ar_loss_weight=1.0,
|
| 185 |
+
global_loss_avg=False,
|
| 186 |
+
dp_varying_mask_ratio=False,
|
| 187 |
+
ada_perm_ratio_per_block=None,
|
| 188 |
+
ada_perm_ratio_global=None,
|
| 189 |
+
ada_dlm_loss_ratio=None,
|
| 190 |
+
complementary_mask=False,
|
| 191 |
+
always_mask_im_end=False,
|
| 192 |
+
im_end_token_id=11,
|
| 193 |
+
**kwargs,
|
| 194 |
+
):
|
| 195 |
+
self.vocab_size = vocab_size
|
| 196 |
+
self.max_position_embeddings = max_position_embeddings
|
| 197 |
+
self.hidden_size = hidden_size
|
| 198 |
+
self.intermediate_size = intermediate_size
|
| 199 |
+
self.num_hidden_layers = num_hidden_layers
|
| 200 |
+
self.num_attention_heads = num_attention_heads
|
| 201 |
+
|
| 202 |
+
# for backward compatibility
|
| 203 |
+
if num_key_value_heads is None:
|
| 204 |
+
num_key_value_heads = num_attention_heads
|
| 205 |
+
|
| 206 |
+
self.num_key_value_heads = num_key_value_heads
|
| 207 |
+
self.head_dim = head_dim
|
| 208 |
+
self.hidden_act = hidden_act
|
| 209 |
+
self.initializer_range = initializer_range
|
| 210 |
+
self.rms_norm_eps = rms_norm_eps
|
| 211 |
+
self.use_cache = use_cache
|
| 212 |
+
self.rope_theta = rope_theta
|
| 213 |
+
self.rope_parameters = rope_parameters
|
| 214 |
+
self.rope_scaling = rope_scaling
|
| 215 |
+
self.attention_bias = attention_bias
|
| 216 |
+
self.attention_dropout = attention_dropout
|
| 217 |
+
self.mlp_bias = mlp_bias
|
| 218 |
+
self.sliding_window = sliding_window
|
| 219 |
+
|
| 220 |
+
rope_config_validation(self)
|
| 221 |
+
|
| 222 |
+
self.attn_implementation = attn_implementation
|
| 223 |
+
|
| 224 |
+
self.mask_token_id = mask_token_id
|
| 225 |
+
self.dlm_type = dlm_type
|
| 226 |
+
self.random_length_prob = random_length_prob
|
| 227 |
+
self.num_ar_layers = num_ar_layers
|
| 228 |
+
self.num_diffusion_layers = num_diffusion_layers
|
| 229 |
+
self.diff_loss_weight = diff_loss_weight
|
| 230 |
+
self.enforce_mask = enforce_mask
|
| 231 |
+
self.prefix_ratio = prefix_ratio
|
| 232 |
+
self.dlm_paradigm = dlm_paradigm
|
| 233 |
+
self.dlm_arch = dlm_arch
|
| 234 |
+
self.block_size = block_size
|
| 235 |
+
self.tok_mask_half_life_ratio = tok_mask_half_life_ratio
|
| 236 |
+
self.adaptive_mask_rate = adaptive_mask_rate
|
| 237 |
+
self.multi_sampling = multi_sampling
|
| 238 |
+
self.num_skip_loss_tokens = num_skip_loss_tokens
|
| 239 |
+
self.dlm_loss_weight = dlm_loss_weight
|
| 240 |
+
self.ar_loss_weight = ar_loss_weight
|
| 241 |
+
self.global_loss_avg = global_loss_avg
|
| 242 |
+
self.dp_varying_mask_ratio = dp_varying_mask_ratio
|
| 243 |
+
self.ada_perm_ratio_per_block = ada_perm_ratio_per_block
|
| 244 |
+
self.ada_perm_ratio_global = ada_perm_ratio_global
|
| 245 |
+
self.ada_dlm_loss_ratio = ada_dlm_loss_ratio
|
| 246 |
+
self.complementary_mask = complementary_mask
|
| 247 |
+
self.always_mask_im_end = always_mask_im_end
|
| 248 |
+
self.im_end_token_id = im_end_token_id
|
| 249 |
+
super().__init__(
|
| 250 |
+
pad_token_id=pad_token_id,
|
| 251 |
+
bos_token_id=bos_token_id,
|
| 252 |
+
eos_token_id=eos_token_id,
|
| 253 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 254 |
+
**kwargs,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
__all__ = ["NemotronLabsDiffusionVLMConfig"]
|
| 259 |
+
|
generation_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": [
|
| 5 |
+
11,
|
| 6 |
+
2
|
| 7 |
+
],
|
| 8 |
+
"pad_token_id": 11,
|
| 9 |
+
"transformers_version": "4.57.1",
|
| 10 |
+
"use_cache": false
|
| 11 |
+
}
|
image_processing.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image processing utilities for Nemotron-Diffusion-Exp-Ministral-8B-Instruct (final-template).
|
| 3 |
+
|
| 4 |
+
Implements image token expansion and pixel value preprocessing,
|
| 5 |
+
faithfully ported from mistral_common.tokens.tokenizers.image.ImageEncoder
|
| 6 |
+
to ensure identical image sizing and token counts.
|
| 7 |
+
|
| 8 |
+
Special token mapping (final-template version):
|
| 9 |
+
<|image_start|> (id=18) = [IMG_START] image start marker
|
| 10 |
+
<|image_pad|> (id=19) = [IMG] image pad token (one per merged patch)
|
| 11 |
+
<|image_break|> (id=20) = [IMG_BREAK] image row break
|
| 12 |
+
<|image_end|> (id=21) = [IMG_END] image end marker
|
| 13 |
+
|
| 14 |
+
After expansion, each image placeholder becomes:
|
| 15 |
+
[IMG_START] ([IMG]*W [IMG_BREAK]) * (H-1) [IMG]*W [IMG_END]
|
| 16 |
+
|
| 17 |
+
where W = width_tokens, H = height_tokens (computed via ceiling division
|
| 18 |
+
on the original image dims, matching mistral_common exactly).
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from io import BytesIO
|
| 23 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 24 |
+
|
| 25 |
+
import cv2
|
| 26 |
+
import numpy as np
|
| 27 |
+
import requests
|
| 28 |
+
import torch
|
| 29 |
+
from PIL import Image
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ── Token strings (must match tokenizer_config.json) ──────────────────────────
|
| 33 |
+
IMG_START_TOKEN = "<|image_start|>" # id = 18
|
| 34 |
+
IMG_PAD_TOKEN = "<|image_pad|>" # id = 19
|
| 35 |
+
IMG_BREAK_TOKEN = "<|image_break|>" # id = 20
|
| 36 |
+
IMG_END_TOKEN = "<|image_end|>" # id = 21
|
| 37 |
+
|
| 38 |
+
# ── Token IDs ─────────────────────────────────────────────────────────────────
|
| 39 |
+
IMG_START_ID = 18
|
| 40 |
+
IMG_PAD_ID = 19
|
| 41 |
+
IMG_BREAK_ID = 20
|
| 42 |
+
IMG_END_ID = 21
|
| 43 |
+
|
| 44 |
+
# ── Default config (from config.json / processor_config.json) ─────────────────
|
| 45 |
+
DEFAULT_PATCH_SIZE = 14
|
| 46 |
+
DEFAULT_SPATIAL_MERGE_SIZE = 2
|
| 47 |
+
DEFAULT_MAX_IMAGE_SIZE = 1400 # longest edge
|
| 48 |
+
# Allow override via environment variable (e.g. from run_all_benchmarks.sh)
|
| 49 |
+
_env_max = os.environ.get("DEFAULT_MAX_IMAGE_SIZE")
|
| 50 |
+
if _env_max is not None and str(_env_max).strip():
|
| 51 |
+
try:
|
| 52 |
+
DEFAULT_MAX_IMAGE_SIZE = int(_env_max)
|
| 53 |
+
except ValueError:
|
| 54 |
+
pass
|
| 55 |
+
|
| 56 |
+
DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) # RGB
|
| 57 |
+
DATASET_STD = (0.26862954, 0.26130258, 0.27577711) # RGB
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 61 |
+
# Image loading (mirrors mistral_common.tokens.tokenizers.image)
|
| 62 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 63 |
+
|
| 64 |
+
def _convert_to_rgb(image: Image.Image) -> Image.Image:
|
| 65 |
+
"""Convert PIL image to RGB; transparent backgrounds become white."""
|
| 66 |
+
if image.mode == "RGB":
|
| 67 |
+
return image
|
| 68 |
+
if image.mode != "RGBA":
|
| 69 |
+
image = image.convert("RGBA")
|
| 70 |
+
white_bg = Image.new("RGBA", image.size, "WHITE")
|
| 71 |
+
white_bg.paste(image, (0, 0), image)
|
| 72 |
+
return white_bg.convert("RGB")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def load_image(source: Union[str, Image.Image]) -> Image.Image:
|
| 76 |
+
"""Load an image from a URL, local file path, or PIL Image."""
|
| 77 |
+
if isinstance(source, Image.Image):
|
| 78 |
+
return source
|
| 79 |
+
if source.startswith(("http://", "https://")):
|
| 80 |
+
resp = requests.get(source, stream=True, timeout=30)
|
| 81 |
+
resp.raise_for_status()
|
| 82 |
+
return Image.open(BytesIO(resp.content))
|
| 83 |
+
return Image.open(source)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 87 |
+
# Core logic — ported from mistral_common ImageEncoder
|
| 88 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 89 |
+
|
| 90 |
+
def _image_to_num_tokens(
|
| 91 |
+
img: Image.Image,
|
| 92 |
+
image_patch_size: int = DEFAULT_PATCH_SIZE,
|
| 93 |
+
max_image_size: int = DEFAULT_MAX_IMAGE_SIZE,
|
| 94 |
+
spatial_merge_size: int = DEFAULT_SPATIAL_MERGE_SIZE,
|
| 95 |
+
) -> Tuple[int, int]:
|
| 96 |
+
"""
|
| 97 |
+
Compute (width_tokens, height_tokens) for a given image — identical to
|
| 98 |
+
``mistral_common.tokens.tokenizers.image.ImageEncoder._image_to_num_tokens``.
|
| 99 |
+
"""
|
| 100 |
+
w, h = img.size # PIL: (W, H)
|
| 101 |
+
ratio = max(h / max_image_size, w / max_image_size)
|
| 102 |
+
if ratio > 1:
|
| 103 |
+
w = round(w / ratio)
|
| 104 |
+
h = round(h / ratio)
|
| 105 |
+
|
| 106 |
+
width_tokens = (w - 1) // (image_patch_size * spatial_merge_size) + 1
|
| 107 |
+
height_tokens = (h - 1) // (image_patch_size * spatial_merge_size) + 1
|
| 108 |
+
return width_tokens, height_tokens
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def transform_image(
|
| 112 |
+
image: Image.Image,
|
| 113 |
+
new_size: Tuple[int, int],
|
| 114 |
+
mean: Tuple[float, ...] = DATASET_MEAN,
|
| 115 |
+
std: Tuple[float, ...] = DATASET_STD,
|
| 116 |
+
) -> np.ndarray:
|
| 117 |
+
"""
|
| 118 |
+
Resize + normalise — identical to
|
| 119 |
+
``mistral_common.tokens.tokenizers.image.transform_image``.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
image: PIL Image (any mode).
|
| 123 |
+
new_size: Target (W, H) — cv2 convention.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
np.ndarray of shape (C, H, W), float32, normalised.
|
| 127 |
+
"""
|
| 128 |
+
np_image = cv2.resize(
|
| 129 |
+
np.array(_convert_to_rgb(image), dtype=np.float32),
|
| 130 |
+
new_size,
|
| 131 |
+
interpolation=cv2.INTER_CUBIC,
|
| 132 |
+
)
|
| 133 |
+
np_image = np_image / 255.0
|
| 134 |
+
np_image = (np_image - np.array(mean, dtype=np.float32)) / np.array(std, dtype=np.float32)
|
| 135 |
+
return np_image.transpose(2, 0, 1)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def encode_image(
|
| 139 |
+
image: Image.Image,
|
| 140 |
+
image_patch_size: int = DEFAULT_PATCH_SIZE,
|
| 141 |
+
max_image_size: int = DEFAULT_MAX_IMAGE_SIZE,
|
| 142 |
+
spatial_merge_size: int = DEFAULT_SPATIAL_MERGE_SIZE,
|
| 143 |
+
) -> Tuple[int, int, np.ndarray]:
|
| 144 |
+
"""
|
| 145 |
+
Compute token dimensions **and** preprocessed pixel array for one image.
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
(width_tokens, height_tokens, pixel_array)
|
| 149 |
+
where pixel_array has shape (C, H, W).
|
| 150 |
+
"""
|
| 151 |
+
w_tok, h_tok = _image_to_num_tokens(
|
| 152 |
+
image, image_patch_size, max_image_size, spatial_merge_size,
|
| 153 |
+
)
|
| 154 |
+
assert w_tok > 0 and h_tok > 0
|
| 155 |
+
|
| 156 |
+
new_w = w_tok * image_patch_size * spatial_merge_size
|
| 157 |
+
new_h = h_tok * image_patch_size * spatial_merge_size
|
| 158 |
+
processed = transform_image(image, (new_w, new_h)) # cv2: (W, H)
|
| 159 |
+
|
| 160 |
+
return w_tok, h_tok, processed
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 164 |
+
# Token string expansion
|
| 165 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 166 |
+
|
| 167 |
+
def build_image_token_str(w_tokens: int, h_tokens: int) -> str:
|
| 168 |
+
"""
|
| 169 |
+
Build the expanded image-token string for one image.
|
| 170 |
+
|
| 171 |
+
Pattern:
|
| 172 |
+
[IMG_START]
|
| 173 |
+
([IMG]*W [IMG_BREAK]) * (H-1)
|
| 174 |
+
[IMG]*W [IMG_END]
|
| 175 |
+
"""
|
| 176 |
+
row = IMG_PAD_TOKEN * w_tokens + IMG_BREAK_TOKEN
|
| 177 |
+
body = row * h_tokens
|
| 178 |
+
body = body[: -len(IMG_BREAK_TOKEN)] + IMG_END_TOKEN
|
| 179 |
+
|
| 180 |
+
return IMG_START_TOKEN + body
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 184 |
+
# Extract image sources from OpenAI-style messages
|
| 185 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 186 |
+
|
| 187 |
+
def _extract_image_sources(messages: List[Dict[str, Any]]) -> List[str]:
|
| 188 |
+
"""Walk through OpenAI-style messages and collect image URLs / paths."""
|
| 189 |
+
sources: List[str] = []
|
| 190 |
+
for msg in messages:
|
| 191 |
+
content = msg.get("content", "")
|
| 192 |
+
if not isinstance(content, list):
|
| 193 |
+
continue
|
| 194 |
+
for block in content:
|
| 195 |
+
btype = block.get("type")
|
| 196 |
+
if btype == "image_url":
|
| 197 |
+
url_obj = block.get("image_url", {})
|
| 198 |
+
sources.append(url_obj.get("url", ""))
|
| 199 |
+
elif btype == "image":
|
| 200 |
+
for key in ("url", "path", "image"):
|
| 201 |
+
if key in block:
|
| 202 |
+
sources.append(block[key])
|
| 203 |
+
break
|
| 204 |
+
return sources
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 208 |
+
# Public API
|
| 209 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 210 |
+
|
| 211 |
+
def process_messages(
|
| 212 |
+
tokenizer,
|
| 213 |
+
messages: List[Dict[str, Any]],
|
| 214 |
+
*,
|
| 215 |
+
patch_size: int = DEFAULT_PATCH_SIZE,
|
| 216 |
+
spatial_merge_size: int = DEFAULT_SPATIAL_MERGE_SIZE,
|
| 217 |
+
max_image_size: int = DEFAULT_MAX_IMAGE_SIZE,
|
| 218 |
+
return_tensors: str = "pt",
|
| 219 |
+
add_generation_prompt: bool = False,
|
| 220 |
+
enable_thinking: bool = True,
|
| 221 |
+
) -> Dict[str, Any]:
|
| 222 |
+
"""
|
| 223 |
+
Process chat messages with optional images — drop-in replacement for
|
| 224 |
+
``MistralCommonBackend.apply_chat_template(return_dict=True)``.
|
| 225 |
+
|
| 226 |
+
Steps:
|
| 227 |
+
1. Render Jinja chat template → prompt with ``<|image_start|>`` placeholders.
|
| 228 |
+
2. For each image:
|
| 229 |
+
a. Load image.
|
| 230 |
+
b. Compute token dims via ceiling division (matching mistral_common).
|
| 231 |
+
c. Resize to token-aligned dimensions with cv2 INTER_CUBIC.
|
| 232 |
+
d. Normalise pixels.
|
| 233 |
+
e. Replace the next ``<|image_start|>`` placeholder with the expanded
|
| 234 |
+
token sequence.
|
| 235 |
+
3. Tokenize the expanded prompt.
|
| 236 |
+
4. Return dict with ``input_ids`` (and ``pixel_values`` / ``image_sizes``
|
| 237 |
+
if images are present).
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
enable_thinking: When True (default), the generation prompt opens a
|
| 241 |
+
``<think>`` block for chain-of-thought reasoning. When False,
|
| 242 |
+
an empty ``<think></think>`` is emitted so the model skips
|
| 243 |
+
the thinking phase.
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
dict with keys:
|
| 247 |
+
input_ids : LongTensor (1, seq_len)
|
| 248 |
+
pixel_values : FloatTensor (N, 3, H, W) – only when images present
|
| 249 |
+
image_sizes : list of (H, W) tuples – only when images present
|
| 250 |
+
"""
|
| 251 |
+
# ── 1. Extract image sources ──────────────────────────────────────────
|
| 252 |
+
image_sources = _extract_image_sources(messages)
|
| 253 |
+
|
| 254 |
+
# ── 2. Render chat template (produces <|image_start|> placeholders) ──
|
| 255 |
+
prompt: str = tokenizer.apply_chat_template(
|
| 256 |
+
messages,
|
| 257 |
+
tokenize=False,
|
| 258 |
+
add_generation_prompt=add_generation_prompt,
|
| 259 |
+
enable_thinking=enable_thinking,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# ── 3. Expand each placeholder & preprocess images ────────────────────
|
| 263 |
+
pixel_list: List[np.ndarray] = []
|
| 264 |
+
image_sizes: List[Tuple[int, int]] = []
|
| 265 |
+
|
| 266 |
+
for src in image_sources:
|
| 267 |
+
pil_img = load_image(src)
|
| 268 |
+
|
| 269 |
+
w_tok, h_tok, pixels = encode_image(
|
| 270 |
+
pil_img, patch_size, max_image_size, spatial_merge_size,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
expanded = build_image_token_str(w_tok, h_tok)
|
| 274 |
+
prompt = prompt.replace(IMG_START_TOKEN, expanded, 1)
|
| 275 |
+
|
| 276 |
+
pixel_list.append(pixels)
|
| 277 |
+
final_h = h_tok * patch_size * spatial_merge_size
|
| 278 |
+
final_w = w_tok * patch_size * spatial_merge_size
|
| 279 |
+
image_sizes.append((final_h, final_w))
|
| 280 |
+
|
| 281 |
+
# ── 4. Tokenize ──────────────────────────────────────────────────────
|
| 282 |
+
if return_tensors == "pt":
|
| 283 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
| 284 |
+
else:
|
| 285 |
+
input_ids = tokenizer(prompt).input_ids
|
| 286 |
+
|
| 287 |
+
result: Dict[str, Any] = {"input_ids": input_ids}
|
| 288 |
+
|
| 289 |
+
if pixel_list:
|
| 290 |
+
if return_tensors == "pt":
|
| 291 |
+
result["pixel_values"] = torch.from_numpy(np.stack(pixel_list))
|
| 292 |
+
else:
|
| 293 |
+
result["pixel_values"] = np.stack(pixel_list)
|
| 294 |
+
result["image_sizes"] = image_sizes
|
| 295 |
+
|
| 296 |
+
return result
|
model-00001-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d6bdc8ac3f1baef94a6a6fe6c5290031495e45ab18a7884f210dd8c157b33582
|
| 3 |
+
size 4984302088
|
model-00002-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e4c3f9cac913ee79685bb55f4eee4c7220ae03ad50bfe83131401a0df023706
|
| 3 |
+
size 4999802904
|
model-00003-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:562159c86be4c9cdde7250b5b78f4636746fd1e82c9e2e10310f74d5696c3503
|
| 3 |
+
size 4915916376
|
model-00004-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7323f1fdacbf54bee85a2b4b074f88767c8d51754eb12495ce48f3f152048276
|
| 3 |
+
size 2936115968
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_parameters": 8918034432,
|
| 4 |
+
"total_size": 17836068864
|
| 5 |
+
},
|
| 6 |
+
"weight_map": {
|
| 7 |
+
"diffusion_head.weight": "model-00004-of-00004.safetensors",
|
| 8 |
+
"encoder.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
| 9 |
+
"encoder.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 10 |
+
"encoder.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 11 |
+
"encoder.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 12 |
+
"encoder.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 13 |
+
"encoder.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 14 |
+
"encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 15 |
+
"encoder.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 16 |
+
"encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 17 |
+
"encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 18 |
+
"encoder.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 19 |
+
"encoder.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 20 |
+
"encoder.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 21 |
+
"encoder.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 22 |
+
"encoder.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 23 |
+
"encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 24 |
+
"encoder.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 25 |
+
"encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 26 |
+
"encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 27 |
+
"encoder.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 28 |
+
"encoder.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 29 |
+
"encoder.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 30 |
+
"encoder.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 31 |
+
"encoder.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 32 |
+
"encoder.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 33 |
+
"encoder.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 34 |
+
"encoder.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 35 |
+
"encoder.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 36 |
+
"encoder.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 37 |
+
"encoder.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 38 |
+
"encoder.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 39 |
+
"encoder.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 40 |
+
"encoder.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 41 |
+
"encoder.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 42 |
+
"encoder.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 43 |
+
"encoder.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 44 |
+
"encoder.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 45 |
+
"encoder.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 46 |
+
"encoder.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 47 |
+
"encoder.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 48 |
+
"encoder.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 49 |
+
"encoder.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 50 |
+
"encoder.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 51 |
+
"encoder.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 52 |
+
"encoder.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 53 |
+
"encoder.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 54 |
+
"encoder.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 55 |
+
"encoder.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 56 |
+
"encoder.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 57 |
+
"encoder.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 58 |
+
"encoder.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 59 |
+
"encoder.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 60 |
+
"encoder.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 61 |
+
"encoder.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 62 |
+
"encoder.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 63 |
+
"encoder.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 64 |
+
"encoder.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 65 |
+
"encoder.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 66 |
+
"encoder.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 67 |
+
"encoder.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 68 |
+
"encoder.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 69 |
+
"encoder.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 70 |
+
"encoder.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 71 |
+
"encoder.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 72 |
+
"encoder.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 73 |
+
"encoder.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 74 |
+
"encoder.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 75 |
+
"encoder.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 76 |
+
"encoder.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 77 |
+
"encoder.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 78 |
+
"encoder.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 79 |
+
"encoder.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 80 |
+
"encoder.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 81 |
+
"encoder.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 82 |
+
"encoder.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 83 |
+
"encoder.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 84 |
+
"encoder.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 85 |
+
"encoder.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 86 |
+
"encoder.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 87 |
+
"encoder.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 88 |
+
"encoder.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 89 |
+
"encoder.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 90 |
+
"encoder.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 91 |
+
"encoder.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 92 |
+
"encoder.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 93 |
+
"encoder.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 94 |
+
"encoder.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 95 |
+
"encoder.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 96 |
+
"encoder.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 97 |
+
"encoder.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 98 |
+
"encoder.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 99 |
+
"encoder.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 100 |
+
"encoder.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 101 |
+
"encoder.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 102 |
+
"encoder.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 103 |
+
"encoder.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 104 |
+
"encoder.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 105 |
+
"encoder.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 106 |
+
"encoder.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 107 |
+
"encoder.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 108 |
+
"encoder.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 109 |
+
"encoder.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 110 |
+
"encoder.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 111 |
+
"encoder.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 112 |
+
"encoder.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 113 |
+
"encoder.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 114 |
+
"encoder.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 115 |
+
"encoder.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 116 |
+
"encoder.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 117 |
+
"encoder.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 118 |
+
"encoder.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 119 |
+
"encoder.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 120 |
+
"encoder.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 121 |
+
"encoder.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 122 |
+
"encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 123 |
+
"encoder.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 124 |
+
"encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 125 |
+
"encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 126 |
+
"encoder.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 127 |
+
"encoder.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 128 |
+
"encoder.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 129 |
+
"encoder.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 130 |
+
"encoder.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 131 |
+
"encoder.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 132 |
+
"encoder.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 133 |
+
"encoder.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 134 |
+
"encoder.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 135 |
+
"encoder.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 136 |
+
"encoder.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 137 |
+
"encoder.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 138 |
+
"encoder.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 139 |
+
"encoder.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 140 |
+
"encoder.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 141 |
+
"encoder.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 142 |
+
"encoder.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 143 |
+
"encoder.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 144 |
+
"encoder.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 145 |
+
"encoder.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 146 |
+
"encoder.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 147 |
+
"encoder.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 148 |
+
"encoder.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 149 |
+
"encoder.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 150 |
+
"encoder.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 151 |
+
"encoder.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 152 |
+
"encoder.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 153 |
+
"encoder.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 154 |
+
"encoder.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 155 |
+
"encoder.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 156 |
+
"encoder.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 157 |
+
"encoder.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 158 |
+
"encoder.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 159 |
+
"encoder.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 160 |
+
"encoder.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 161 |
+
"encoder.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 162 |
+
"encoder.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 163 |
+
"encoder.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 164 |
+
"encoder.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 165 |
+
"encoder.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 166 |
+
"encoder.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 167 |
+
"encoder.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 168 |
+
"encoder.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 169 |
+
"encoder.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 170 |
+
"encoder.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 171 |
+
"encoder.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 172 |
+
"encoder.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 173 |
+
"encoder.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 174 |
+
"encoder.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 175 |
+
"encoder.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 176 |
+
"encoder.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 177 |
+
"encoder.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 178 |
+
"encoder.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 179 |
+
"encoder.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 180 |
+
"encoder.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 181 |
+
"encoder.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 182 |
+
"encoder.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 183 |
+
"encoder.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 184 |
+
"encoder.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 185 |
+
"encoder.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 186 |
+
"encoder.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 187 |
+
"encoder.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 188 |
+
"encoder.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 189 |
+
"encoder.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 190 |
+
"encoder.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 191 |
+
"encoder.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 192 |
+
"encoder.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 193 |
+
"encoder.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 194 |
+
"encoder.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 195 |
+
"encoder.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 196 |
+
"encoder.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 197 |
+
"encoder.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 198 |
+
"encoder.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 199 |
+
"encoder.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 200 |
+
"encoder.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 201 |
+
"encoder.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 202 |
+
"encoder.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 203 |
+
"encoder.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 204 |
+
"encoder.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 205 |
+
"encoder.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 206 |
+
"encoder.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 207 |
+
"encoder.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 208 |
+
"encoder.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 209 |
+
"encoder.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 210 |
+
"encoder.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 211 |
+
"encoder.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 212 |
+
"encoder.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 213 |
+
"encoder.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 214 |
+
"encoder.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 215 |
+
"encoder.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 216 |
+
"encoder.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 217 |
+
"encoder.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 218 |
+
"encoder.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 219 |
+
"encoder.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 220 |
+
"encoder.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 221 |
+
"encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 222 |
+
"encoder.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 223 |
+
"encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 224 |
+
"encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 225 |
+
"encoder.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 226 |
+
"encoder.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 227 |
+
"encoder.layers.30.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
| 228 |
+
"encoder.layers.30.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
| 229 |
+
"encoder.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 230 |
+
"encoder.layers.30.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
| 231 |
+
"encoder.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
| 232 |
+
"encoder.layers.30.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
| 233 |
+
"encoder.layers.30.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
| 234 |
+
"encoder.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 235 |
+
"encoder.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 236 |
+
"encoder.layers.31.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
| 237 |
+
"encoder.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
| 238 |
+
"encoder.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 239 |
+
"encoder.layers.31.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
| 240 |
+
"encoder.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
| 241 |
+
"encoder.layers.31.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
| 242 |
+
"encoder.layers.31.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
| 243 |
+
"encoder.layers.32.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 244 |
+
"encoder.layers.32.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 245 |
+
"encoder.layers.32.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
| 246 |
+
"encoder.layers.32.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
| 247 |
+
"encoder.layers.32.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 248 |
+
"encoder.layers.32.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
| 249 |
+
"encoder.layers.32.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
| 250 |
+
"encoder.layers.32.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
| 251 |
+
"encoder.layers.32.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
| 252 |
+
"encoder.layers.33.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 253 |
+
"encoder.layers.33.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 254 |
+
"encoder.layers.33.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
| 255 |
+
"encoder.layers.33.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
| 256 |
+
"encoder.layers.33.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 257 |
+
"encoder.layers.33.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
| 258 |
+
"encoder.layers.33.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
| 259 |
+
"encoder.layers.33.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
| 260 |
+
"encoder.layers.33.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
| 261 |
+
"encoder.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 262 |
+
"encoder.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 263 |
+
"encoder.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 264 |
+
"encoder.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 265 |
+
"encoder.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 266 |
+
"encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 267 |
+
"encoder.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 268 |
+
"encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 269 |
+
"encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 270 |
+
"encoder.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 271 |
+
"encoder.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 272 |
+
"encoder.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 273 |
+
"encoder.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 274 |
+
"encoder.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 275 |
+
"encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 276 |
+
"encoder.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 277 |
+
"encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 278 |
+
"encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 279 |
+
"encoder.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 280 |
+
"encoder.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 281 |
+
"encoder.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 282 |
+
"encoder.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 283 |
+
"encoder.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 284 |
+
"encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 285 |
+
"encoder.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 286 |
+
"encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 287 |
+
"encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 288 |
+
"encoder.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 289 |
+
"encoder.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 290 |
+
"encoder.layers.7.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 291 |
+
"encoder.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 292 |
+
"encoder.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 293 |
+
"encoder.layers.7.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 294 |
+
"encoder.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 295 |
+
"encoder.layers.7.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 296 |
+
"encoder.layers.7.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 297 |
+
"encoder.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 298 |
+
"encoder.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 299 |
+
"encoder.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 300 |
+
"encoder.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 301 |
+
"encoder.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 302 |
+
"encoder.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 303 |
+
"encoder.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 304 |
+
"encoder.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 305 |
+
"encoder.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 306 |
+
"encoder.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 307 |
+
"encoder.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 308 |
+
"encoder.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 309 |
+
"encoder.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 310 |
+
"encoder.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 311 |
+
"encoder.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 312 |
+
"encoder.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 313 |
+
"encoder.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 314 |
+
"encoder.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 315 |
+
"encoder.multi_modal_projector.linear_1.weight": "model-00001-of-00004.safetensors",
|
| 316 |
+
"encoder.multi_modal_projector.linear_2.weight": "model-00001-of-00004.safetensors",
|
| 317 |
+
"encoder.multi_modal_projector.norm.weight": "model-00001-of-00004.safetensors",
|
| 318 |
+
"encoder.multi_modal_projector.patch_merger.merging_layer.weight": "model-00001-of-00004.safetensors",
|
| 319 |
+
"encoder.norm.weight": "model-00004-of-00004.safetensors",
|
| 320 |
+
"encoder.vision_tower.ln_pre.weight": "model-00001-of-00004.safetensors",
|
| 321 |
+
"encoder.vision_tower.patch_conv.weight": "model-00001-of-00004.safetensors",
|
| 322 |
+
"encoder.vision_tower.transformer.layers.0.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 323 |
+
"encoder.vision_tower.transformer.layers.0.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 324 |
+
"encoder.vision_tower.transformer.layers.0.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 325 |
+
"encoder.vision_tower.transformer.layers.0.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 326 |
+
"encoder.vision_tower.transformer.layers.0.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 327 |
+
"encoder.vision_tower.transformer.layers.0.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 328 |
+
"encoder.vision_tower.transformer.layers.0.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 329 |
+
"encoder.vision_tower.transformer.layers.0.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 330 |
+
"encoder.vision_tower.transformer.layers.0.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 331 |
+
"encoder.vision_tower.transformer.layers.1.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 332 |
+
"encoder.vision_tower.transformer.layers.1.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 333 |
+
"encoder.vision_tower.transformer.layers.1.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 334 |
+
"encoder.vision_tower.transformer.layers.1.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 335 |
+
"encoder.vision_tower.transformer.layers.1.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 336 |
+
"encoder.vision_tower.transformer.layers.1.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 337 |
+
"encoder.vision_tower.transformer.layers.1.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 338 |
+
"encoder.vision_tower.transformer.layers.1.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 339 |
+
"encoder.vision_tower.transformer.layers.1.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 340 |
+
"encoder.vision_tower.transformer.layers.10.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 341 |
+
"encoder.vision_tower.transformer.layers.10.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 342 |
+
"encoder.vision_tower.transformer.layers.10.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 343 |
+
"encoder.vision_tower.transformer.layers.10.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 344 |
+
"encoder.vision_tower.transformer.layers.10.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 345 |
+
"encoder.vision_tower.transformer.layers.10.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 346 |
+
"encoder.vision_tower.transformer.layers.10.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 347 |
+
"encoder.vision_tower.transformer.layers.10.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 348 |
+
"encoder.vision_tower.transformer.layers.10.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 349 |
+
"encoder.vision_tower.transformer.layers.11.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 350 |
+
"encoder.vision_tower.transformer.layers.11.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 351 |
+
"encoder.vision_tower.transformer.layers.11.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 352 |
+
"encoder.vision_tower.transformer.layers.11.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 353 |
+
"encoder.vision_tower.transformer.layers.11.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 354 |
+
"encoder.vision_tower.transformer.layers.11.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 355 |
+
"encoder.vision_tower.transformer.layers.11.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 356 |
+
"encoder.vision_tower.transformer.layers.11.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 357 |
+
"encoder.vision_tower.transformer.layers.11.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 358 |
+
"encoder.vision_tower.transformer.layers.12.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 359 |
+
"encoder.vision_tower.transformer.layers.12.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 360 |
+
"encoder.vision_tower.transformer.layers.12.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 361 |
+
"encoder.vision_tower.transformer.layers.12.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 362 |
+
"encoder.vision_tower.transformer.layers.12.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 363 |
+
"encoder.vision_tower.transformer.layers.12.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 364 |
+
"encoder.vision_tower.transformer.layers.12.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 365 |
+
"encoder.vision_tower.transformer.layers.12.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 366 |
+
"encoder.vision_tower.transformer.layers.12.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 367 |
+
"encoder.vision_tower.transformer.layers.13.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 368 |
+
"encoder.vision_tower.transformer.layers.13.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 369 |
+
"encoder.vision_tower.transformer.layers.13.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 370 |
+
"encoder.vision_tower.transformer.layers.13.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 371 |
+
"encoder.vision_tower.transformer.layers.13.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 372 |
+
"encoder.vision_tower.transformer.layers.13.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 373 |
+
"encoder.vision_tower.transformer.layers.13.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 374 |
+
"encoder.vision_tower.transformer.layers.13.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 375 |
+
"encoder.vision_tower.transformer.layers.13.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 376 |
+
"encoder.vision_tower.transformer.layers.14.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 377 |
+
"encoder.vision_tower.transformer.layers.14.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 378 |
+
"encoder.vision_tower.transformer.layers.14.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 379 |
+
"encoder.vision_tower.transformer.layers.14.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 380 |
+
"encoder.vision_tower.transformer.layers.14.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 381 |
+
"encoder.vision_tower.transformer.layers.14.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 382 |
+
"encoder.vision_tower.transformer.layers.14.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 383 |
+
"encoder.vision_tower.transformer.layers.14.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 384 |
+
"encoder.vision_tower.transformer.layers.14.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 385 |
+
"encoder.vision_tower.transformer.layers.15.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 386 |
+
"encoder.vision_tower.transformer.layers.15.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 387 |
+
"encoder.vision_tower.transformer.layers.15.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 388 |
+
"encoder.vision_tower.transformer.layers.15.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 389 |
+
"encoder.vision_tower.transformer.layers.15.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 390 |
+
"encoder.vision_tower.transformer.layers.15.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 391 |
+
"encoder.vision_tower.transformer.layers.15.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 392 |
+
"encoder.vision_tower.transformer.layers.15.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 393 |
+
"encoder.vision_tower.transformer.layers.15.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 394 |
+
"encoder.vision_tower.transformer.layers.16.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 395 |
+
"encoder.vision_tower.transformer.layers.16.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 396 |
+
"encoder.vision_tower.transformer.layers.16.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 397 |
+
"encoder.vision_tower.transformer.layers.16.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 398 |
+
"encoder.vision_tower.transformer.layers.16.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 399 |
+
"encoder.vision_tower.transformer.layers.16.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 400 |
+
"encoder.vision_tower.transformer.layers.16.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 401 |
+
"encoder.vision_tower.transformer.layers.16.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 402 |
+
"encoder.vision_tower.transformer.layers.16.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 403 |
+
"encoder.vision_tower.transformer.layers.17.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 404 |
+
"encoder.vision_tower.transformer.layers.17.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 405 |
+
"encoder.vision_tower.transformer.layers.17.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 406 |
+
"encoder.vision_tower.transformer.layers.17.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 407 |
+
"encoder.vision_tower.transformer.layers.17.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 408 |
+
"encoder.vision_tower.transformer.layers.17.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 409 |
+
"encoder.vision_tower.transformer.layers.17.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 410 |
+
"encoder.vision_tower.transformer.layers.17.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 411 |
+
"encoder.vision_tower.transformer.layers.17.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 412 |
+
"encoder.vision_tower.transformer.layers.18.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 413 |
+
"encoder.vision_tower.transformer.layers.18.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 414 |
+
"encoder.vision_tower.transformer.layers.18.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 415 |
+
"encoder.vision_tower.transformer.layers.18.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 416 |
+
"encoder.vision_tower.transformer.layers.18.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 417 |
+
"encoder.vision_tower.transformer.layers.18.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 418 |
+
"encoder.vision_tower.transformer.layers.18.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 419 |
+
"encoder.vision_tower.transformer.layers.18.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 420 |
+
"encoder.vision_tower.transformer.layers.18.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 421 |
+
"encoder.vision_tower.transformer.layers.19.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 422 |
+
"encoder.vision_tower.transformer.layers.19.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 423 |
+
"encoder.vision_tower.transformer.layers.19.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 424 |
+
"encoder.vision_tower.transformer.layers.19.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 425 |
+
"encoder.vision_tower.transformer.layers.19.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 426 |
+
"encoder.vision_tower.transformer.layers.19.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 427 |
+
"encoder.vision_tower.transformer.layers.19.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 428 |
+
"encoder.vision_tower.transformer.layers.19.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 429 |
+
"encoder.vision_tower.transformer.layers.19.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 430 |
+
"encoder.vision_tower.transformer.layers.2.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 431 |
+
"encoder.vision_tower.transformer.layers.2.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 432 |
+
"encoder.vision_tower.transformer.layers.2.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 433 |
+
"encoder.vision_tower.transformer.layers.2.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 434 |
+
"encoder.vision_tower.transformer.layers.2.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 435 |
+
"encoder.vision_tower.transformer.layers.2.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 436 |
+
"encoder.vision_tower.transformer.layers.2.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 437 |
+
"encoder.vision_tower.transformer.layers.2.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 438 |
+
"encoder.vision_tower.transformer.layers.2.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 439 |
+
"encoder.vision_tower.transformer.layers.20.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 440 |
+
"encoder.vision_tower.transformer.layers.20.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 441 |
+
"encoder.vision_tower.transformer.layers.20.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 442 |
+
"encoder.vision_tower.transformer.layers.20.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 443 |
+
"encoder.vision_tower.transformer.layers.20.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 444 |
+
"encoder.vision_tower.transformer.layers.20.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 445 |
+
"encoder.vision_tower.transformer.layers.20.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 446 |
+
"encoder.vision_tower.transformer.layers.20.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 447 |
+
"encoder.vision_tower.transformer.layers.20.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 448 |
+
"encoder.vision_tower.transformer.layers.21.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 449 |
+
"encoder.vision_tower.transformer.layers.21.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 450 |
+
"encoder.vision_tower.transformer.layers.21.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 451 |
+
"encoder.vision_tower.transformer.layers.21.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 452 |
+
"encoder.vision_tower.transformer.layers.21.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 453 |
+
"encoder.vision_tower.transformer.layers.21.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 454 |
+
"encoder.vision_tower.transformer.layers.21.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 455 |
+
"encoder.vision_tower.transformer.layers.21.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 456 |
+
"encoder.vision_tower.transformer.layers.21.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 457 |
+
"encoder.vision_tower.transformer.layers.22.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 458 |
+
"encoder.vision_tower.transformer.layers.22.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 459 |
+
"encoder.vision_tower.transformer.layers.22.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 460 |
+
"encoder.vision_tower.transformer.layers.22.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 461 |
+
"encoder.vision_tower.transformer.layers.22.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 462 |
+
"encoder.vision_tower.transformer.layers.22.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 463 |
+
"encoder.vision_tower.transformer.layers.22.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 464 |
+
"encoder.vision_tower.transformer.layers.22.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 465 |
+
"encoder.vision_tower.transformer.layers.22.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 466 |
+
"encoder.vision_tower.transformer.layers.23.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 467 |
+
"encoder.vision_tower.transformer.layers.23.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 468 |
+
"encoder.vision_tower.transformer.layers.23.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 469 |
+
"encoder.vision_tower.transformer.layers.23.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 470 |
+
"encoder.vision_tower.transformer.layers.23.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 471 |
+
"encoder.vision_tower.transformer.layers.23.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 472 |
+
"encoder.vision_tower.transformer.layers.23.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 473 |
+
"encoder.vision_tower.transformer.layers.23.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 474 |
+
"encoder.vision_tower.transformer.layers.23.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 475 |
+
"encoder.vision_tower.transformer.layers.3.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 476 |
+
"encoder.vision_tower.transformer.layers.3.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 477 |
+
"encoder.vision_tower.transformer.layers.3.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 478 |
+
"encoder.vision_tower.transformer.layers.3.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 479 |
+
"encoder.vision_tower.transformer.layers.3.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 480 |
+
"encoder.vision_tower.transformer.layers.3.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 481 |
+
"encoder.vision_tower.transformer.layers.3.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 482 |
+
"encoder.vision_tower.transformer.layers.3.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 483 |
+
"encoder.vision_tower.transformer.layers.3.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 484 |
+
"encoder.vision_tower.transformer.layers.4.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 485 |
+
"encoder.vision_tower.transformer.layers.4.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 486 |
+
"encoder.vision_tower.transformer.layers.4.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 487 |
+
"encoder.vision_tower.transformer.layers.4.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 488 |
+
"encoder.vision_tower.transformer.layers.4.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 489 |
+
"encoder.vision_tower.transformer.layers.4.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 490 |
+
"encoder.vision_tower.transformer.layers.4.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 491 |
+
"encoder.vision_tower.transformer.layers.4.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 492 |
+
"encoder.vision_tower.transformer.layers.4.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 493 |
+
"encoder.vision_tower.transformer.layers.5.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 494 |
+
"encoder.vision_tower.transformer.layers.5.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 495 |
+
"encoder.vision_tower.transformer.layers.5.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 496 |
+
"encoder.vision_tower.transformer.layers.5.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 497 |
+
"encoder.vision_tower.transformer.layers.5.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 498 |
+
"encoder.vision_tower.transformer.layers.5.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 499 |
+
"encoder.vision_tower.transformer.layers.5.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 500 |
+
"encoder.vision_tower.transformer.layers.5.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 501 |
+
"encoder.vision_tower.transformer.layers.5.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 502 |
+
"encoder.vision_tower.transformer.layers.6.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 503 |
+
"encoder.vision_tower.transformer.layers.6.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 504 |
+
"encoder.vision_tower.transformer.layers.6.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 505 |
+
"encoder.vision_tower.transformer.layers.6.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 506 |
+
"encoder.vision_tower.transformer.layers.6.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 507 |
+
"encoder.vision_tower.transformer.layers.6.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 508 |
+
"encoder.vision_tower.transformer.layers.6.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 509 |
+
"encoder.vision_tower.transformer.layers.6.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 510 |
+
"encoder.vision_tower.transformer.layers.6.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 511 |
+
"encoder.vision_tower.transformer.layers.7.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 512 |
+
"encoder.vision_tower.transformer.layers.7.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 513 |
+
"encoder.vision_tower.transformer.layers.7.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 514 |
+
"encoder.vision_tower.transformer.layers.7.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 515 |
+
"encoder.vision_tower.transformer.layers.7.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 516 |
+
"encoder.vision_tower.transformer.layers.7.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 517 |
+
"encoder.vision_tower.transformer.layers.7.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 518 |
+
"encoder.vision_tower.transformer.layers.7.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 519 |
+
"encoder.vision_tower.transformer.layers.7.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 520 |
+
"encoder.vision_tower.transformer.layers.8.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 521 |
+
"encoder.vision_tower.transformer.layers.8.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 522 |
+
"encoder.vision_tower.transformer.layers.8.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 523 |
+
"encoder.vision_tower.transformer.layers.8.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 524 |
+
"encoder.vision_tower.transformer.layers.8.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 525 |
+
"encoder.vision_tower.transformer.layers.8.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 526 |
+
"encoder.vision_tower.transformer.layers.8.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 527 |
+
"encoder.vision_tower.transformer.layers.8.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 528 |
+
"encoder.vision_tower.transformer.layers.8.ffn_norm.weight": "model-00001-of-00004.safetensors",
|
| 529 |
+
"encoder.vision_tower.transformer.layers.9.attention.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 530 |
+
"encoder.vision_tower.transformer.layers.9.attention.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 531 |
+
"encoder.vision_tower.transformer.layers.9.attention.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 532 |
+
"encoder.vision_tower.transformer.layers.9.attention.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 533 |
+
"encoder.vision_tower.transformer.layers.9.attention_norm.weight": "model-00001-of-00004.safetensors",
|
| 534 |
+
"encoder.vision_tower.transformer.layers.9.feed_forward.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 535 |
+
"encoder.vision_tower.transformer.layers.9.feed_forward.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 536 |
+
"encoder.vision_tower.transformer.layers.9.feed_forward.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 537 |
+
"encoder.vision_tower.transformer.layers.9.ffn_norm.weight": "model-00001-of-00004.safetensors"
|
| 538 |
+
}
|
| 539 |
+
}
|
model_cards/bias.md
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Field | Response
|
| 2 |
+
:---------------------------------------------------------------------------------------------------|:---------------
|
| 3 |
+
Participation considerations from adversely impacted groups [protected classes](https://www.senate.ca.gov/content/protected-classes) in model design and testing: | [None]
|
| 4 |
+
Measures taken to mitigate against unwanted bias: | [None]
|
model_cards/explainability.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Field | Response
|
| 2 |
+
:------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------
|
| 3 |
+
Intended Task/Domain: | Text generation
|
| 4 |
+
Model Type: | Transformer
|
| 5 |
+
Intended Users: | Generative AI creators working with conversational AI models.
|
| 6 |
+
Output: | Text (Responds to posed question, Stateful - remembers previous answers)
|
| 7 |
+
Describe how the model works: | Text input is encoded into tokens and passed into a transformer-based language model, which returns a text response.
|
| 8 |
+
Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable
|
| 9 |
+
Technical Limitations & Mitigation: | The model cannot perform long-horizon reasoning and tool calling.
|
| 10 |
+
Verified to have met prescribed NVIDIA quality standards: | Yes
|
| 11 |
+
Performance Metrics: | Accuracy, Latency, Throughput
|
| 12 |
+
Potential Known Risks: | In some instances, the model may think too long and struggle to derive final answers. The model's output can generate all forms of text, including what may be considered toxic, offensive, or indecent.
|
| 13 |
+
Licensing: | nvidia-open-model-license.
|
model_cards/privacy.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Field | Response
|
| 2 |
+
:----------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------
|
| 3 |
+
Generatable or reverse engineerable personal data? | [No]
|
| 4 |
+
Personal data used to create this model? | [No]
|
| 5 |
+
Was consent obtained for any personal data used? | [Not Applicable]
|
| 6 |
+
How often is dataset reviewed? | [During dataset creation, model training, evaluation, and the prerelease phase.]
|
| 7 |
+
Was data from user interactions with the AI model (e.g. user input and prompts) used to train the model? | [Yes]
|
| 8 |
+
Is there provenance for all datasets used in training? | Yes
|
| 9 |
+
Does data labeling (annotation, metadata) comply with privacy laws? | Yes
|
| 10 |
+
Is data compliant with data subject requests for data correction or removal, if such a request was made? | Not Applicable.
|
| 11 |
+
Applicable Privacy Policy | https://www.nvidia.com/en-us/about-nvidia/privacy-policy/
|
model_cards/safety.md
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Field | Response
|
| 2 |
+
:---------------------------------------------------|:----------------------------------
|
| 3 |
+
Model Application Field(s): | [Media & Entertainment].
|
| 4 |
+
Describe the life critical impact (if present). | Not Applicable
|
| 5 |
+
Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to.
|
| 6 |
+
Use Case Restrictions: | Abide by nvidia-open-model-license.
|
modeling_ministral.py
ADDED
|
@@ -0,0 +1,629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Callable
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from transformers.utils.generic import check_model_inputs
|
| 8 |
+
|
| 9 |
+
from transformers.activations import ACT2FN
|
| 10 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 11 |
+
from transformers.generation import GenerationMixin
|
| 12 |
+
# from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
|
| 13 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 14 |
+
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask, ALL_MASK_ATTENTION_FUNCTIONS, sdpa_mask_older_torch
|
| 15 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 16 |
+
from transformers.modeling_layers import (
|
| 17 |
+
GenericForQuestionAnswering,
|
| 18 |
+
GenericForSequenceClassification,
|
| 19 |
+
GenericForTokenClassification,
|
| 20 |
+
GradientCheckpointingLayer,
|
| 21 |
+
)
|
| 22 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 23 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 24 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 25 |
+
from transformers.processing_utils import Unpack
|
| 26 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
|
| 27 |
+
from transformers.models.pixtral.modeling_pixtral import PixtralVisionModel
|
| 28 |
+
from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig
|
| 29 |
+
# from transformers.utils.generic import maybe_autocast
|
| 30 |
+
from .configuration_nemotron_labs_diffusion_vlm import NemotronLabsDiffusionVLMConfig
|
| 31 |
+
|
| 32 |
+
#ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Ministral3PatchMerger(nn.Module):
|
| 36 |
+
"""
|
| 37 |
+
Learned merging of spatial_merge_size ** 2 patches
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, config):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.config = config
|
| 43 |
+
|
| 44 |
+
hidden_size = config.vision_config.hidden_size
|
| 45 |
+
self.spatial_merge_size = config.spatial_merge_size
|
| 46 |
+
self.patch_size = self.config.vision_config.patch_size
|
| 47 |
+
self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
|
| 48 |
+
|
| 49 |
+
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
|
| 50 |
+
image_sizes = [
|
| 51 |
+
(image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
tokens_per_image = [h * w for h, w in image_sizes]
|
| 55 |
+
d = image_features.shape[-1]
|
| 56 |
+
|
| 57 |
+
permuted_tensor = []
|
| 58 |
+
for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
|
| 59 |
+
# Reshape image_tokens into a 2D grid
|
| 60 |
+
h, w = image_sizes[image_index]
|
| 61 |
+
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
|
| 62 |
+
grid = torch.nn.functional.unfold(
|
| 63 |
+
image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
|
| 64 |
+
)
|
| 65 |
+
grid = grid.view(d * self.spatial_merge_size**2, -1).t()
|
| 66 |
+
permuted_tensor.append(grid)
|
| 67 |
+
|
| 68 |
+
image_features = torch.cat(permuted_tensor, dim=0)
|
| 69 |
+
image_features = self.merging_layer(image_features)
|
| 70 |
+
return image_features
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Ministral3MultiModalProjector(nn.Module):
|
| 75 |
+
def __init__(self, config):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.norm = Ministral3RMSNorm(config.vision_config.hidden_size, eps=config.rms_norm_eps)
|
| 78 |
+
self.patch_merger = Ministral3PatchMerger(config)
|
| 79 |
+
# We have hidden_size * the number of vision feature layers
|
| 80 |
+
self.num_feature_layers = (
|
| 81 |
+
1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
|
| 82 |
+
)
|
| 83 |
+
self.linear_1 = nn.Linear(
|
| 84 |
+
config.vision_config.hidden_size * self.num_feature_layers,
|
| 85 |
+
config.hidden_size,
|
| 86 |
+
bias=config.multimodal_projector_bias,
|
| 87 |
+
)
|
| 88 |
+
self.act = ACT2FN[config.projector_hidden_act]
|
| 89 |
+
self.linear_2 = nn.Linear(
|
| 90 |
+
config.hidden_size, config.hidden_size, bias=config.multimodal_projector_bias
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
|
| 94 |
+
image_features = self.norm(image_features)
|
| 95 |
+
image_features = self.patch_merger(image_features, image_sizes)
|
| 96 |
+
hidden_states = self.linear_1(image_features)
|
| 97 |
+
hidden_states = self.act(hidden_states)
|
| 98 |
+
hidden_states = self.linear_2(hidden_states)
|
| 99 |
+
return hidden_states
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def rotate_half(x):
|
| 103 |
+
"""Rotates half the hidden dims of the input."""
|
| 104 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 105 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 106 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 107 |
+
|
| 108 |
+
# @use_kernel_func_from_hub("rotary_pos_emb")
|
| 109 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 110 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
q (`torch.Tensor`): The query tensor.
|
| 114 |
+
k (`torch.Tensor`): The key tensor.
|
| 115 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 116 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 117 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 118 |
+
Deprecated and unused.
|
| 119 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 120 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 121 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 122 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 123 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 124 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 125 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 126 |
+
Returns:
|
| 127 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 128 |
+
"""
|
| 129 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 130 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 131 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 132 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 133 |
+
return q_embed, k_embed
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 137 |
+
"""
|
| 138 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 139 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 140 |
+
"""
|
| 141 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 142 |
+
if n_rep == 1:
|
| 143 |
+
return hidden_states
|
| 144 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 145 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def eager_attention_forward(
|
| 149 |
+
module: nn.Module,
|
| 150 |
+
query: torch.Tensor,
|
| 151 |
+
key: torch.Tensor,
|
| 152 |
+
value: torch.Tensor,
|
| 153 |
+
attention_mask: Optional[torch.Tensor],
|
| 154 |
+
scaling: float,
|
| 155 |
+
dropout: float = 0.0,
|
| 156 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 157 |
+
):
|
| 158 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 159 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 160 |
+
|
| 161 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 162 |
+
if attention_mask is not None:
|
| 163 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 164 |
+
attn_weights = attn_weights + causal_mask
|
| 165 |
+
|
| 166 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 167 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 168 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 169 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 170 |
+
|
| 171 |
+
return attn_output, attn_weights
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
|
| 175 |
+
scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
|
| 176 |
+
return scaling.unsqueeze(-1)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# @use_kernelized_func(apply_rotary_pos_emb)
|
| 180 |
+
class Ministral3Attention(nn.Module):
|
| 181 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 182 |
+
|
| 183 |
+
def __init__(self, config: NemotronLabsDiffusionVLMConfig, layer_idx: int):
|
| 184 |
+
super().__init__()
|
| 185 |
+
self.config = config
|
| 186 |
+
self.layer_idx = layer_idx
|
| 187 |
+
self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 188 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 189 |
+
self.scaling = self.head_dim**-0.5
|
| 190 |
+
self.attention_dropout = config.attention_dropout
|
| 191 |
+
self.is_causal = True
|
| 192 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
|
| 193 |
+
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
|
| 194 |
+
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
|
| 195 |
+
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
| 196 |
+
|
| 197 |
+
self.diffusion_lm = config.diffusion_lm
|
| 198 |
+
|
| 199 |
+
def forward(
|
| 200 |
+
self,
|
| 201 |
+
hidden_states: torch.Tensor,
|
| 202 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 203 |
+
attention_mask: Optional[torch.Tensor],
|
| 204 |
+
past_key_values: Optional[Cache] = None,
|
| 205 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 206 |
+
use_cache: Optional[bool] = False,
|
| 207 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 208 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 209 |
+
input_shape = hidden_states.shape[:-1]
|
| 210 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 211 |
+
|
| 212 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 213 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 214 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 215 |
+
|
| 216 |
+
cos, sin = position_embeddings
|
| 217 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 218 |
+
query_states = query_states * _get_llama_4_attn_scale(
|
| 219 |
+
cache_position,
|
| 220 |
+
self.config.rope_parameters.get("llama_4_scaling_beta"),
|
| 221 |
+
self.config.rope_parameters.get("original_max_position_embeddings"),
|
| 222 |
+
).to(query_states.dtype)
|
| 223 |
+
|
| 224 |
+
if past_key_values is not None:
|
| 225 |
+
if use_cache:
|
| 226 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 227 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 228 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 229 |
+
else: ## if use_cache == False, do not update cache
|
| 230 |
+
old_k, old_v = past_key_values.layers[self.layer_idx].keys, past_key_values.layers[self.layer_idx].values
|
| 231 |
+
key_states = torch.cat([old_k, key_states], dim=-2)
|
| 232 |
+
value_states = torch.cat([old_v, value_states], dim=-2)
|
| 233 |
+
|
| 234 |
+
attention_interface: Callable = eager_attention_forward
|
| 235 |
+
if self.config._attn_implementation != "eager":
|
| 236 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 237 |
+
|
| 238 |
+
if self.diffusion_lm:
|
| 239 |
+
attn_output, attn_weights = attention_interface(
|
| 240 |
+
self,
|
| 241 |
+
query_states,
|
| 242 |
+
key_states,
|
| 243 |
+
value_states,
|
| 244 |
+
None,
|
| 245 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 246 |
+
scaling=self.scaling,
|
| 247 |
+
is_causal=False,
|
| 248 |
+
**kwargs,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
else:
|
| 252 |
+
attn_output, attn_weights = attention_interface(
|
| 253 |
+
self,
|
| 254 |
+
query_states,
|
| 255 |
+
key_states,
|
| 256 |
+
value_states,
|
| 257 |
+
attention_mask,
|
| 258 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 259 |
+
scaling=self.scaling,
|
| 260 |
+
sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
|
| 261 |
+
**kwargs,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 265 |
+
attn_output = self.o_proj(attn_output)
|
| 266 |
+
return attn_output, attn_weights
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class Ministral3MLP(nn.Module):
|
| 270 |
+
def __init__(self, config):
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.config = config
|
| 273 |
+
self.hidden_size = config.hidden_size
|
| 274 |
+
self.intermediate_size = config.intermediate_size
|
| 275 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 276 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 277 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 278 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 279 |
+
|
| 280 |
+
def forward(self, x):
|
| 281 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 282 |
+
return down_proj
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 286 |
+
class Ministral3RMSNorm(nn.Module):
|
| 287 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 288 |
+
"""
|
| 289 |
+
Ministral3RMSNorm is equivalent to T5LayerNorm
|
| 290 |
+
"""
|
| 291 |
+
super().__init__()
|
| 292 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 293 |
+
self.variance_epsilon = eps
|
| 294 |
+
|
| 295 |
+
def forward(self, hidden_states):
|
| 296 |
+
input_dtype = hidden_states.dtype
|
| 297 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 298 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 299 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 300 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 301 |
+
|
| 302 |
+
def extra_repr(self):
|
| 303 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class Ministral3DecoderLayer(GradientCheckpointingLayer):
|
| 307 |
+
def __init__(self, config: NemotronLabsDiffusionVLMConfig, layer_idx: int):
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.hidden_size = config.hidden_size
|
| 310 |
+
|
| 311 |
+
if hasattr(config, 'attn_class'):
|
| 312 |
+
attn_class = config.attn_class
|
| 313 |
+
else:
|
| 314 |
+
attn_class = Ministral3Attention
|
| 315 |
+
|
| 316 |
+
self.self_attn = attn_class(config=config, layer_idx=layer_idx)
|
| 317 |
+
self.mlp = Ministral3MLP(config)
|
| 318 |
+
self.input_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 319 |
+
self.post_attention_layernorm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 320 |
+
|
| 321 |
+
def forward(
|
| 322 |
+
self,
|
| 323 |
+
hidden_states: torch.Tensor,
|
| 324 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 325 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 326 |
+
past_key_values: Optional[Cache] = None,
|
| 327 |
+
use_cache: Optional[bool] = False,
|
| 328 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 329 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 330 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 331 |
+
) -> torch.Tensor:
|
| 332 |
+
residual = hidden_states
|
| 333 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 334 |
+
# Self Attention
|
| 335 |
+
hidden_states, _ = self.self_attn(
|
| 336 |
+
hidden_states=hidden_states,
|
| 337 |
+
attention_mask=attention_mask,
|
| 338 |
+
position_ids=position_ids,
|
| 339 |
+
past_key_values=past_key_values,
|
| 340 |
+
use_cache=use_cache,
|
| 341 |
+
cache_position=cache_position,
|
| 342 |
+
position_embeddings=position_embeddings,
|
| 343 |
+
**kwargs,
|
| 344 |
+
)
|
| 345 |
+
hidden_states = residual + hidden_states
|
| 346 |
+
|
| 347 |
+
# Fully Connected
|
| 348 |
+
residual = hidden_states
|
| 349 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 350 |
+
hidden_states = self.mlp(hidden_states)
|
| 351 |
+
hidden_states = residual + hidden_states
|
| 352 |
+
return hidden_states
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@auto_docstring
|
| 356 |
+
class Ministral3PreTrainedModel(PreTrainedModel):
|
| 357 |
+
config: NemotronLabsDiffusionVLMConfig
|
| 358 |
+
base_model_prefix = "model"
|
| 359 |
+
supports_gradient_checkpointing = True
|
| 360 |
+
# Ministral3RMSNorm must be a separate FSDP unit to avoid weight sharded to size 0 on some ranks
|
| 361 |
+
_no_split_modules = ["Ministral3DecoderLayer", "Ministral3RMSNorm"]
|
| 362 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 363 |
+
_supports_flash_attn = True
|
| 364 |
+
_supports_sdpa = True
|
| 365 |
+
_supports_flex_attn = True
|
| 366 |
+
|
| 367 |
+
_can_compile_fullgraph = True
|
| 368 |
+
_supports_attention_backend = True
|
| 369 |
+
_can_record_outputs = {
|
| 370 |
+
"hidden_states": Ministral3DecoderLayer,
|
| 371 |
+
"attentions": Ministral3Attention,
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class Ministral3RotaryEmbedding(nn.Module):
|
| 376 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 377 |
+
|
| 378 |
+
def __init__(self, config: NemotronLabsDiffusionVLMConfig, device=None):
|
| 379 |
+
super().__init__()
|
| 380 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 381 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 382 |
+
|
| 383 |
+
self.config = config
|
| 384 |
+
|
| 385 |
+
self.rope_type = self.config.rope_parameters["rope_type"]
|
| 386 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 387 |
+
if self.rope_type != "default":
|
| 388 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 389 |
+
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
| 390 |
+
|
| 391 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 392 |
+
self.original_inv_freq = inv_freq
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
@staticmethod
|
| 396 |
+
def compute_default_rope_parameters(
|
| 397 |
+
config: Optional[NemotronLabsDiffusionVLMConfig] = None,
|
| 398 |
+
device: Optional["torch.device"] = None,
|
| 399 |
+
seq_len: Optional[int] = None,
|
| 400 |
+
) -> tuple["torch.Tensor", float]:
|
| 401 |
+
"""
|
| 402 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 403 |
+
Args:
|
| 404 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 405 |
+
The model configuration.
|
| 406 |
+
device (`torch.device`):
|
| 407 |
+
The device to use for initialization of the inverse frequencies.
|
| 408 |
+
seq_len (`int`, *optional*):
|
| 409 |
+
The current sequence length. Unused for this type of RoPE.
|
| 410 |
+
Returns:
|
| 411 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 412 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 413 |
+
"""
|
| 414 |
+
base = config.rope_parameters["rope_theta"]
|
| 415 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 416 |
+
|
| 417 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 418 |
+
|
| 419 |
+
# Compute the inverse frequencies
|
| 420 |
+
inv_freq = 1.0 / (
|
| 421 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 422 |
+
)
|
| 423 |
+
return inv_freq, attention_factor
|
| 424 |
+
|
| 425 |
+
@torch.no_grad()
|
| 426 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 427 |
+
def forward(self, x, position_ids):
|
| 428 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 429 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 430 |
+
|
| 431 |
+
# device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 432 |
+
# with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 433 |
+
|
| 434 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 435 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 436 |
+
cos = emb.cos() * self.attention_scaling
|
| 437 |
+
sin = emb.sin() * self.attention_scaling
|
| 438 |
+
|
| 439 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
@auto_docstring
|
| 443 |
+
class Ministral3Model(Ministral3PreTrainedModel):
|
| 444 |
+
def __init__(self, config: NemotronLabsDiffusionVLMConfig):
|
| 445 |
+
super().__init__(config)
|
| 446 |
+
vision_config = config.vision_config
|
| 447 |
+
if not isinstance(vision_config, PixtralVisionConfig):
|
| 448 |
+
vision_config = PixtralVisionConfig(**vision_config) if isinstance(vision_config, dict) else PixtralVisionConfig(**vars(vision_config))
|
| 449 |
+
config.vision_config = vision_config
|
| 450 |
+
|
| 451 |
+
self.vision_tower = PixtralVisionModel(vision_config)
|
| 452 |
+
self.multi_modal_projector = Ministral3MultiModalProjector(config)
|
| 453 |
+
self.padding_idx = config.pad_token_id
|
| 454 |
+
self.vocab_size = config.vocab_size
|
| 455 |
+
|
| 456 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 457 |
+
self.layers = nn.ModuleList(
|
| 458 |
+
[Ministral3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 459 |
+
)
|
| 460 |
+
self.norm = Ministral3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 461 |
+
self.rotary_emb = Ministral3RotaryEmbedding(config=config)
|
| 462 |
+
self.gradient_checkpointing = False
|
| 463 |
+
|
| 464 |
+
# Initialize weights and apply final processing
|
| 465 |
+
self.post_init()
|
| 466 |
+
|
| 467 |
+
@check_model_inputs
|
| 468 |
+
@auto_docstring
|
| 469 |
+
def forward(
|
| 470 |
+
self,
|
| 471 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 472 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 473 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 474 |
+
past_key_values: Optional[Cache] = None,
|
| 475 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 476 |
+
use_cache: Optional[bool] = None,
|
| 477 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 478 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 479 |
+
) -> BaseModelOutputWithPast:
|
| 480 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 481 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 482 |
+
|
| 483 |
+
if inputs_embeds is None:
|
| 484 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 485 |
+
|
| 486 |
+
if use_cache and past_key_values is None:
|
| 487 |
+
# past_key_values = DynamicCache(config=self.config)
|
| 488 |
+
past_key_values = DynamicCache()
|
| 489 |
+
|
| 490 |
+
if cache_position is None:
|
| 491 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 492 |
+
cache_position = torch.arange(
|
| 493 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
if position_ids is None:
|
| 497 |
+
position_ids = cache_position.unsqueeze(0)
|
| 498 |
+
|
| 499 |
+
if kwargs.get("use_causal_mask", False):
|
| 500 |
+
mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
|
| 501 |
+
causal_mask = mask_function(
|
| 502 |
+
config=self.config,
|
| 503 |
+
input_embeds=inputs_embeds,
|
| 504 |
+
attention_mask=attention_mask,
|
| 505 |
+
cache_position=cache_position,
|
| 506 |
+
past_key_values=past_key_values,
|
| 507 |
+
position_ids=position_ids,
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
else:
|
| 511 |
+
causal_mask = None
|
| 512 |
+
|
| 513 |
+
hidden_states = inputs_embeds
|
| 514 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
|
| 515 |
+
|
| 516 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 517 |
+
hidden_states = decoder_layer(
|
| 518 |
+
hidden_states,
|
| 519 |
+
attention_mask=causal_mask,
|
| 520 |
+
position_ids=position_ids,
|
| 521 |
+
past_key_values=past_key_values,
|
| 522 |
+
use_cache=use_cache,
|
| 523 |
+
cache_position=cache_position,
|
| 524 |
+
position_embeddings=position_embeddings,
|
| 525 |
+
**kwargs,
|
| 526 |
+
)
|
| 527 |
+
hidden_states = self.norm(hidden_states)
|
| 528 |
+
return BaseModelOutputWithPast(
|
| 529 |
+
last_hidden_state=hidden_states,
|
| 530 |
+
past_key_values=past_key_values if use_cache else None,
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
@auto_docstring
|
| 535 |
+
class Ministral3ForCausalLM(Ministral3PreTrainedModel, GenerationMixin):
|
| 536 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 537 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 538 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 539 |
+
|
| 540 |
+
def __init__(self, config):
|
| 541 |
+
super().__init__(config)
|
| 542 |
+
self.model = Ministral3Model(config)
|
| 543 |
+
self.vocab_size = config.vocab_size
|
| 544 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 545 |
+
|
| 546 |
+
# Initialize weights and apply final processing
|
| 547 |
+
self.post_init()
|
| 548 |
+
|
| 549 |
+
@can_return_tuple
|
| 550 |
+
@auto_docstring
|
| 551 |
+
def forward(
|
| 552 |
+
self,
|
| 553 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 554 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 555 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 556 |
+
past_key_values: Optional[Cache] = None,
|
| 557 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 558 |
+
labels: Optional[torch.LongTensor] = None,
|
| 559 |
+
use_cache: Optional[bool] = None,
|
| 560 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 561 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 562 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 563 |
+
) -> CausalLMOutputWithPast:
|
| 564 |
+
r"""
|
| 565 |
+
Example:
|
| 566 |
+
|
| 567 |
+
```python
|
| 568 |
+
>>> from transformers import AutoTokenizer, Ministral3ForCausalLM
|
| 569 |
+
|
| 570 |
+
>>> model = Ministral3ForCausalLM.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
|
| 571 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-ministral3/Ministral3-2-7b-hf")
|
| 572 |
+
|
| 573 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 574 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 575 |
+
|
| 576 |
+
>>> # Generate
|
| 577 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 578 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 579 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 580 |
+
```"""
|
| 581 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 582 |
+
input_ids=input_ids,
|
| 583 |
+
attention_mask=attention_mask,
|
| 584 |
+
position_ids=position_ids,
|
| 585 |
+
past_key_values=past_key_values,
|
| 586 |
+
inputs_embeds=inputs_embeds,
|
| 587 |
+
use_cache=use_cache,
|
| 588 |
+
cache_position=cache_position,
|
| 589 |
+
**kwargs,
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
hidden_states = outputs.last_hidden_state
|
| 593 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 594 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 595 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 596 |
+
|
| 597 |
+
loss = None
|
| 598 |
+
if labels is not None:
|
| 599 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 600 |
+
|
| 601 |
+
return CausalLMOutputWithPast(
|
| 602 |
+
loss=loss,
|
| 603 |
+
logits=logits,
|
| 604 |
+
past_key_values=outputs.past_key_values,
|
| 605 |
+
hidden_states=outputs.hidden_states,
|
| 606 |
+
attentions=outputs.attentions,
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
|
| 611 |
+
pass
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
|
| 615 |
+
pass
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
|
| 619 |
+
pass
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
__all__ = [
|
| 623 |
+
"Ministral3ForCausalLM",
|
| 624 |
+
"Ministral3ForQuestionAnswering",
|
| 625 |
+
"Ministral3Model",
|
| 626 |
+
"Ministral3PreTrainedModel",
|
| 627 |
+
"Ministral3ForSequenceClassification",
|
| 628 |
+
"Ministral3ForTokenClassification",
|
| 629 |
+
]
|
modeling_nemotron_labs_diffusion_vlm.py
ADDED
|
@@ -0,0 +1,1378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Callable, Optional, Tuple, Union
|
| 4 |
+
import random
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import json
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch import nn
|
| 13 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
|
| 14 |
+
from transformers.utils import ModelOutput
|
| 15 |
+
|
| 16 |
+
from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, or_masks
|
| 17 |
+
|
| 18 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 19 |
+
|
| 20 |
+
from transformers.processing_utils import Unpack
|
| 21 |
+
|
| 22 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 23 |
+
|
| 24 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from transformers.generation import GenerationMixin
|
| 27 |
+
from transformers.loss.loss_utils import LOSS_MAPPING
|
| 28 |
+
|
| 29 |
+
import math
|
| 30 |
+
|
| 31 |
+
from .chat_utils import generate_with_prefix_cache_block_diff
|
| 32 |
+
from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
|
| 33 |
+
from .configuration_nemotron_labs_diffusion_vlm import NemotronLabsDiffusionVLMConfig
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class NemotronLabsDiffusionVLMOutputWithPast(ModelOutput):
|
| 38 |
+
loss: torch.FloatTensor | None = None
|
| 39 |
+
logits: torch.FloatTensor | None = None
|
| 40 |
+
causal_logits: torch.FloatTensor | None = None
|
| 41 |
+
past_key_values: Cache | None = None
|
| 42 |
+
hidden_states: tuple[torch.FloatTensor, ...] | None = None
|
| 43 |
+
attentions: tuple[torch.FloatTensor, ...] | None = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# @torch.compile(dynamic=True, mode="reduce-overhead")
|
| 47 |
+
# @torch.compile(mode="default")
|
| 48 |
+
# @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
|
| 49 |
+
@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False)
|
| 50 |
+
def fused_flex_attention(q, k, v, block_mask=None):
|
| 51 |
+
return flex_attention(q, k, v, block_mask=block_mask)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _crop_dynamic_cache(past_key_values: DynamicCache, max_length: int):
|
| 55 |
+
"""Crop a DynamicCache to max_length, compatible with both old and new transformers."""
|
| 56 |
+
if hasattr(past_key_values, 'crop'):
|
| 57 |
+
past_key_values.crop(max_length)
|
| 58 |
+
else:
|
| 59 |
+
for layer_idx in range(len(past_key_values)):
|
| 60 |
+
past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][:, :, :max_length]
|
| 61 |
+
past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][:, :, :max_length]
|
| 62 |
+
past_key_values._seen_tokens = max_length
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block_length: int):
|
| 66 |
+
"""After quadratic decoding, extract only draft tokens (first of each block) from cache."""
|
| 67 |
+
for layer_idx in range(len(past_key_values)):
|
| 68 |
+
if hasattr(past_key_values, 'layers'):
|
| 69 |
+
layer_cache = past_key_values.layers[layer_idx]
|
| 70 |
+
k, v = layer_cache.keys, layer_cache.values
|
| 71 |
+
else:
|
| 72 |
+
k = past_key_values.key_cache[layer_idx]
|
| 73 |
+
v = past_key_values.value_cache[layer_idx]
|
| 74 |
+
|
| 75 |
+
clean_k, draft_k = k[:, :, :clean_len], k[:, :, clean_len::block_length + 1]
|
| 76 |
+
clean_v, draft_v = v[:, :, :clean_len], v[:, :, clean_len::block_length + 1]
|
| 77 |
+
new_k = torch.cat([clean_k, draft_k], dim=2)
|
| 78 |
+
new_v = torch.cat([clean_v, draft_v], dim=2)
|
| 79 |
+
|
| 80 |
+
if hasattr(past_key_values, 'layers'):
|
| 81 |
+
layer_cache.keys = new_k
|
| 82 |
+
layer_cache.values = new_v
|
| 83 |
+
else:
|
| 84 |
+
past_key_values.key_cache[layer_idx] = new_k
|
| 85 |
+
past_key_values.value_cache[layer_idx] = new_v
|
| 86 |
+
|
| 87 |
+
past_key_values._seen_tokens = clean_len + block_length
|
| 88 |
+
|
| 89 |
+
# with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
|
| 90 |
+
class NemotronLabsDiffusionVLMFlexAttention(Ministral3Attention):
|
| 91 |
+
def __init__(self, *args, **kwargs):
|
| 92 |
+
super().__init__(*args, **kwargs)
|
| 93 |
+
|
| 94 |
+
self.max_seq_length = getattr(self.config, 'max_seq_length', 4096)
|
| 95 |
+
self.block_size_orig = self.config.block_size
|
| 96 |
+
|
| 97 |
+
if self.config.dlm_paradigm == 'bidirectional':
|
| 98 |
+
self.bidirectional_mask = self.compute_block_mask(mode='bidirectional')
|
| 99 |
+
elif self.config.dlm_paradigm == 'autoregressive':
|
| 100 |
+
self.autoregressive_mask = self.compute_block_mask(mode='autoregressive')
|
| 101 |
+
elif self.config.dlm_paradigm == 'block_diff':
|
| 102 |
+
self.block_diff_mask = None
|
| 103 |
+
elif self.config.dlm_paradigm == 'sbd_block_diff':
|
| 104 |
+
self.sbd_block_diff_mask = None
|
| 105 |
+
else:
|
| 106 |
+
raise ValueError(f"Unknown attention mode: {self.config.dlm_paradigm}")
|
| 107 |
+
|
| 108 |
+
self.block_size = self.block_size_orig
|
| 109 |
+
self.mode = self.config.dlm_paradigm
|
| 110 |
+
self._quadratic_block_mask = {}
|
| 111 |
+
|
| 112 |
+
import torch._dynamo.config as dcfg
|
| 113 |
+
dcfg.cache_size_limit = 512
|
| 114 |
+
|
| 115 |
+
def _get_sbd_inference_quadratic_decoding_block_mask(self, block_length: int):
|
| 116 |
+
if block_length not in self._quadratic_block_mask:
|
| 117 |
+
draft_len = block_length * (block_length + 1)
|
| 118 |
+
|
| 119 |
+
def quadratic(b, h, q_idx, kv_idx):
|
| 120 |
+
first_clean = torch.logical_and(
|
| 121 |
+
kv_idx % (block_length + 1) == 0,
|
| 122 |
+
kv_idx < draft_len,
|
| 123 |
+
)
|
| 124 |
+
first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
|
| 125 |
+
block_q = q_idx // (block_length + 1)
|
| 126 |
+
block_kv = kv_idx // (block_length + 1)
|
| 127 |
+
same_block = torch.logical_and(block_q == block_kv, q_idx < draft_len)
|
| 128 |
+
same_block_except_first = torch.logical_and(
|
| 129 |
+
same_block,
|
| 130 |
+
q_idx % (block_length + 1) != 0,
|
| 131 |
+
)
|
| 132 |
+
draft_part = torch.logical_or(first_clean, same_block_except_first)
|
| 133 |
+
clean_part = kv_idx >= draft_len
|
| 134 |
+
return torch.logical_or(draft_part, clean_part)
|
| 135 |
+
|
| 136 |
+
block_mask = create_block_mask(
|
| 137 |
+
quadratic,
|
| 138 |
+
B=None,
|
| 139 |
+
H=None,
|
| 140 |
+
Q_LEN=draft_len,
|
| 141 |
+
KV_LEN=draft_len + self.config.max_position_embeddings,
|
| 142 |
+
device="cuda",
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self._quadratic_block_mask[block_length] = block_mask
|
| 146 |
+
|
| 147 |
+
return self._quadratic_block_mask[block_length]
|
| 148 |
+
|
| 149 |
+
def set_attention_mode(self, mode, block_size=None):
|
| 150 |
+
self.mode = mode
|
| 151 |
+
self.block_size = block_size
|
| 152 |
+
|
| 153 |
+
def compute_block_mask(self, mode, q_len=None, block_size=None):
|
| 154 |
+
|
| 155 |
+
def bidirectional_mask(b, h, q, kv):
|
| 156 |
+
return (q >= kv) | (q < kv)
|
| 157 |
+
|
| 158 |
+
def autoregressive_mask(b, h, q, kv):
|
| 159 |
+
return (q >= kv)
|
| 160 |
+
|
| 161 |
+
def block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
|
| 162 |
+
"""
|
| 163 |
+
Constructs the specialized block diffusion attention mask for training
|
| 164 |
+
composed of three masks:
|
| 165 |
+
- **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
|
| 166 |
+
- **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
|
| 167 |
+
- **Block Causal Mask (M_BC)**: Attention to update x0
|
| 168 |
+
Args:
|
| 169 |
+
b, h: Batch and head indices (ignored for mask logic).
|
| 170 |
+
q_idx, kv_idx: Query and Key indices.
|
| 171 |
+
seq_len: Total sequence length.
|
| 172 |
+
block_size: Defines the block structure.
|
| 173 |
+
Returns:
|
| 174 |
+
A boolean attention mask.
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
# Indicate whether token belongs to xt or x0
|
| 178 |
+
x0_flag_q = (q_idx >= n)
|
| 179 |
+
x0_flag_kv = (kv_idx >= n)
|
| 180 |
+
|
| 181 |
+
# Compute block indices
|
| 182 |
+
block_q = torch.where(x0_flag_q == 1,
|
| 183 |
+
(q_idx - n) // block_size,
|
| 184 |
+
q_idx // block_size)
|
| 185 |
+
block_kv = torch.where(x0_flag_kv == 1,
|
| 186 |
+
(kv_idx - n) // block_size,
|
| 187 |
+
kv_idx // block_size)
|
| 188 |
+
|
| 189 |
+
# **1. Block Diagonal Mask (M_BD) **
|
| 190 |
+
block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
|
| 191 |
+
|
| 192 |
+
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 193 |
+
offset_block_causal = (
|
| 194 |
+
(block_q > block_kv)
|
| 195 |
+
& (x0_flag_kv == 1)
|
| 196 |
+
& (x0_flag_q == 0)
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# **3. Block-Causal Mask (M_BC) **
|
| 200 |
+
block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
|
| 201 |
+
|
| 202 |
+
# **4. Combine Masks **
|
| 203 |
+
return block_diagonal | offset_block_causal | block_causal
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def sbd_block_diff_mask(block_size, b, h, q_idx, kv_idx, n):
|
| 207 |
+
x0_flag_q = (q_idx >= n)
|
| 208 |
+
x0_flag_kv = (kv_idx >= n)
|
| 209 |
+
|
| 210 |
+
# Compute block indices
|
| 211 |
+
block_q = torch.where(x0_flag_q == 1,
|
| 212 |
+
(q_idx - n) // block_size,
|
| 213 |
+
q_idx // block_size)
|
| 214 |
+
block_kv = torch.where(x0_flag_kv == 1,
|
| 215 |
+
(kv_idx - n) // block_size,
|
| 216 |
+
kv_idx // block_size)
|
| 217 |
+
|
| 218 |
+
# **1. Block Diagonal Mask (M_BD) **
|
| 219 |
+
block_diagonal = (block_q == block_kv) & (x0_flag_kv == 0) & (x0_flag_q == 0)
|
| 220 |
+
|
| 221 |
+
# **2. Offset Block-Causal Mask (M_OBC) **
|
| 222 |
+
offset_block_causal = (
|
| 223 |
+
(block_q > block_kv)
|
| 224 |
+
& (x0_flag_kv == 1)
|
| 225 |
+
& (x0_flag_q == 0)
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# **3. Fully Causal Mask (M_BC) **
|
| 229 |
+
fully_causal = (q_idx >= kv_idx) & (x0_flag_kv == 1) & (x0_flag_q == 1)
|
| 230 |
+
|
| 231 |
+
# **4. Combine Masks **
|
| 232 |
+
return block_diagonal | offset_block_causal | fully_causal
|
| 233 |
+
|
| 234 |
+
if mode == 'bidirectional':
|
| 235 |
+
attn_mask = bidirectional_mask
|
| 236 |
+
elif mode == 'autoregressive':
|
| 237 |
+
attn_mask = autoregressive_mask
|
| 238 |
+
elif mode == 'block_diff':
|
| 239 |
+
assert block_size is not None
|
| 240 |
+
n = (q_len // 2) if q_len is not None else self.max_seq_length
|
| 241 |
+
attn_mask = lambda b, h, q, kv: block_diff_mask(block_size, b, h, q, kv, n)
|
| 242 |
+
elif mode == 'sbd_block_diff':
|
| 243 |
+
assert block_size is not None
|
| 244 |
+
n = (q_len // 2) if q_len is not None else self.max_seq_length
|
| 245 |
+
attn_mask = lambda b, h, q, kv: sbd_block_diff_mask(block_size, b, h, q, kv, n)
|
| 246 |
+
else:
|
| 247 |
+
raise ValueError(f"Unknown attention mode: {mode}")
|
| 248 |
+
|
| 249 |
+
if q_len is not None:
|
| 250 |
+
Q_LEN = q_len
|
| 251 |
+
else:
|
| 252 |
+
if mode in ['block_diff', 'sbd_block_diff']:
|
| 253 |
+
Q_LEN = self.max_seq_length * 2
|
| 254 |
+
else:
|
| 255 |
+
Q_LEN = self.max_seq_length
|
| 256 |
+
|
| 257 |
+
block_mask = create_block_mask(
|
| 258 |
+
attn_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=Q_LEN
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
return block_mask
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def forward(
|
| 265 |
+
self,
|
| 266 |
+
hidden_states: torch.Tensor,
|
| 267 |
+
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 268 |
+
attention_mask: Optional[torch.Tensor],
|
| 269 |
+
past_key_values: Optional[Cache] = None,
|
| 270 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 271 |
+
is_training: bool = True,
|
| 272 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 273 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 274 |
+
bsz, q_len, _ = hidden_states.size()
|
| 275 |
+
input_shape = hidden_states.shape[:-1]
|
| 276 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 277 |
+
|
| 278 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 279 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 280 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 281 |
+
|
| 282 |
+
cos, sin = position_embeddings
|
| 283 |
+
|
| 284 |
+
if self.mode in ['block_diff', 'sbd_block_diff'] and is_training:
|
| 285 |
+
# Split query and key states in half along sequence length dimension
|
| 286 |
+
q1, q2 = query_states.chunk(2, dim=2)
|
| 287 |
+
k1, k2 = key_states.chunk(2, dim=2)
|
| 288 |
+
|
| 289 |
+
# Apply RoPE independently to each half
|
| 290 |
+
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
| 291 |
+
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
| 292 |
+
|
| 293 |
+
# Recombine the halves
|
| 294 |
+
query_states = torch.cat([q1, q2], dim=2)
|
| 295 |
+
key_states = torch.cat([k1, k2], dim=2)
|
| 296 |
+
else:
|
| 297 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 298 |
+
|
| 299 |
+
query_states = query_states * _get_llama_4_attn_scale(
|
| 300 |
+
cache_position,
|
| 301 |
+
self.config.rope_parameters.get("llama_4_scaling_beta"),
|
| 302 |
+
self.config.rope_parameters.get("original_max_position_embeddings"),
|
| 303 |
+
).to(query_states.dtype)
|
| 304 |
+
|
| 305 |
+
if past_key_values is not None:
|
| 306 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 307 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 308 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 309 |
+
|
| 310 |
+
self_spec_inference_mode = getattr(self.config, "self_spec_inference_mode", None)
|
| 311 |
+
if self_spec_inference_mode is not None:
|
| 312 |
+
if self_spec_inference_mode == "quadratic":
|
| 313 |
+
block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
|
| 314 |
+
if block_length is None:
|
| 315 |
+
raise ValueError("SBD quadratic decoding requires block_length in config.")
|
| 316 |
+
if past_key_values is not None:
|
| 317 |
+
seq_len = key_states.shape[2]
|
| 318 |
+
draft_len = block_length * (block_length + 1)
|
| 319 |
+
|
| 320 |
+
clean_keys = key_states[:, :, :-draft_len]
|
| 321 |
+
draft_keys = key_states[:, :, -draft_len:]
|
| 322 |
+
clean_values = value_states[:, :, :-draft_len]
|
| 323 |
+
draft_values = value_states[:, :, -draft_len:]
|
| 324 |
+
key_states = torch.cat([draft_keys, clean_keys], dim=2)
|
| 325 |
+
value_states = torch.cat([draft_values, clean_values], dim=2)
|
| 326 |
+
|
| 327 |
+
block_mask = self._get_sbd_inference_quadratic_decoding_block_mask(
|
| 328 |
+
block_length=block_length
|
| 329 |
+
)
|
| 330 |
+
block_mask.seq_lengths = (draft_len, seq_len)
|
| 331 |
+
else:
|
| 332 |
+
seq_len = query_states.shape[2]
|
| 333 |
+
draft_len = block_length * (block_length + 1)
|
| 334 |
+
clean_len = seq_len - draft_len
|
| 335 |
+
|
| 336 |
+
def _causal_mask(b, h, q_idx, kv_idx):
|
| 337 |
+
return torch.logical_and(q_idx >= kv_idx, q_idx < clean_len)
|
| 338 |
+
|
| 339 |
+
def _draft2clean_mask(b, h, q_idx, kv_idx):
|
| 340 |
+
full_clean = torch.logical_and(q_idx >= clean_len, kv_idx <= clean_len)
|
| 341 |
+
first_clean = torch.logical_and(
|
| 342 |
+
q_idx >= clean_len, (kv_idx - clean_len) % (block_length + 1) == 0
|
| 343 |
+
)
|
| 344 |
+
first_clean = torch.logical_and(first_clean, q_idx >= kv_idx)
|
| 345 |
+
return torch.logical_or(full_clean, first_clean)
|
| 346 |
+
|
| 347 |
+
def _draft_mask(b, h, q_idx, kv_idx):
|
| 348 |
+
block_q = (q_idx - clean_len) // (block_length + 1)
|
| 349 |
+
block_kv = (kv_idx - clean_len) // (block_length + 1)
|
| 350 |
+
quadrant = torch.logical_and(q_idx >= clean_len, kv_idx >= clean_len)
|
| 351 |
+
same_block = torch.logical_and(block_q == block_kv, quadrant)
|
| 352 |
+
same_block_except_first = torch.logical_and(
|
| 353 |
+
same_block,
|
| 354 |
+
(q_idx - clean_len) % (block_length + 1) != 0,
|
| 355 |
+
)
|
| 356 |
+
return torch.logical_and(block_q == block_kv, same_block_except_first)
|
| 357 |
+
|
| 358 |
+
mask = or_masks(_causal_mask, _draft2clean_mask)
|
| 359 |
+
mask = or_masks(mask, _draft_mask)
|
| 360 |
+
|
| 361 |
+
block_mask = create_block_mask(
|
| 362 |
+
mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 366 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 367 |
+
attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
|
| 368 |
+
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
| 369 |
+
attn_output = self.o_proj(attn_output)
|
| 370 |
+
return attn_output, None
|
| 371 |
+
|
| 372 |
+
elif self_spec_inference_mode == "default":
|
| 373 |
+
block_length = getattr(self.config, "block_length", None) or getattr(self.config, "block_size", None)
|
| 374 |
+
if block_length is None:
|
| 375 |
+
raise ValueError("SBD default decoding requires block_length in config.")
|
| 376 |
+
seq_len = query_states.shape[2]
|
| 377 |
+
prefix_len = seq_len - block_length
|
| 378 |
+
|
| 379 |
+
def _clean_q_mask(b, h, q_idx, kv_idx):
|
| 380 |
+
return torch.logical_and(q_idx >= kv_idx, q_idx < prefix_len)
|
| 381 |
+
|
| 382 |
+
def _noisy_q_mask(b, h, q_idx, kv_idx):
|
| 383 |
+
return q_idx >= prefix_len
|
| 384 |
+
|
| 385 |
+
block_mask = create_block_mask(
|
| 386 |
+
or_masks(_clean_q_mask, _noisy_q_mask),
|
| 387 |
+
B=None,
|
| 388 |
+
H=None,
|
| 389 |
+
Q_LEN=seq_len,
|
| 390 |
+
KV_LEN=seq_len,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 394 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 395 |
+
attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
|
| 396 |
+
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
| 397 |
+
attn_output = self.o_proj(attn_output)
|
| 398 |
+
return attn_output, None
|
| 399 |
+
|
| 400 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 401 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 402 |
+
|
| 403 |
+
if self.mode == 'bidirectional':
|
| 404 |
+
if self.bidirectional_mask is None or q_len != self.bidirectional_mask.shape[-2]:
|
| 405 |
+
block_mask = self.compute_block_mask(mode='bidirectional', q_len=q_len)
|
| 406 |
+
else:
|
| 407 |
+
block_mask = self.bidirectional_mask
|
| 408 |
+
|
| 409 |
+
elif self.mode == 'autoregressive':
|
| 410 |
+
if self.autoregressive_mask is None or q_len != self.autoregressive_mask.shape[-2]:
|
| 411 |
+
block_mask = self.compute_block_mask(mode='autoregressive', q_len=q_len)
|
| 412 |
+
else:
|
| 413 |
+
block_mask = self.autoregressive_mask
|
| 414 |
+
|
| 415 |
+
elif self.mode == 'block_diff':
|
| 416 |
+
if self.block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.block_diff_mask.shape[-2]:
|
| 417 |
+
block_mask = self.compute_block_mask(mode='block_diff', block_size=self.block_size, q_len=q_len)
|
| 418 |
+
else:
|
| 419 |
+
block_mask = self.block_diff_mask
|
| 420 |
+
elif self.mode == 'sbd_block_diff':
|
| 421 |
+
if self.sbd_block_diff_mask is None or self.block_size != self.block_size_orig or q_len != self.sbd_block_diff_mask.shape[-2]:
|
| 422 |
+
block_mask = self.compute_block_mask(mode='sbd_block_diff', block_size=self.block_size, q_len=q_len)
|
| 423 |
+
else:
|
| 424 |
+
block_mask = self.sbd_block_diff_mask
|
| 425 |
+
else:
|
| 426 |
+
raise ValueError(f"Unknown attention mode: {self.mode}")
|
| 427 |
+
|
| 428 |
+
attn_output = fused_flex_attention(query_states, key_states, value_states, block_mask=block_mask)
|
| 429 |
+
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
| 430 |
+
|
| 431 |
+
attn_output = self.o_proj(attn_output)
|
| 432 |
+
|
| 433 |
+
return attn_output, None
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
|
| 437 |
+
"""Return a Bool mask of length len(log_w) with exactly k True."""
|
| 438 |
+
g = -torch.log(-torch.log(torch.rand_like(log_w) + 1e-9) + 1e-9)
|
| 439 |
+
topk = torch.topk(log_w + g, k).indices
|
| 440 |
+
mask = torch.zeros_like(log_w, dtype=torch.bool)
|
| 441 |
+
mask[topk] = True
|
| 442 |
+
return mask
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class NemotronLabsDiffusionVLMModel(Ministral3PreTrainedModel, GenerationMixin):
|
| 446 |
+
"""
|
| 447 |
+
A single model with:
|
| 448 |
+
- a bidirectional encoder + diffusion‐LM head over A
|
| 449 |
+
- a causal decoder + LM head over B, conditioned on F_A
|
| 450 |
+
"""
|
| 451 |
+
|
| 452 |
+
def __init__(self, config: NemotronLabsDiffusionVLMConfig):
|
| 453 |
+
super().__init__(config)
|
| 454 |
+
|
| 455 |
+
self.mask_token_id = config.mask_token_id
|
| 456 |
+
|
| 457 |
+
diffusion_config = copy.deepcopy(config)
|
| 458 |
+
diffusion_config.diffusion_lm = True
|
| 459 |
+
|
| 460 |
+
use_flex = getattr(config, 'enable_self_spec', False)
|
| 461 |
+
|
| 462 |
+
if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 463 |
+
diffusion_config.attn_class = NemotronLabsDiffusionVLMFlexAttention
|
| 464 |
+
elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
|
| 465 |
+
diffusion_config.attn_class = NemotronLabsDiffusionVLMFlexAttention if use_flex else Ministral3Attention
|
| 466 |
+
if config.dlm_paradigm == 'autoregressive':
|
| 467 |
+
diffusion_config.diffusion_lm = False
|
| 468 |
+
else:
|
| 469 |
+
raise ValueError(f"Unsupported DLM paradigm: {config.dlm_paradigm}")
|
| 470 |
+
|
| 471 |
+
self.encoder = Ministral3Model(diffusion_config)
|
| 472 |
+
self.diffusion_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 473 |
+
self.vocab_size = config.vocab_size
|
| 474 |
+
|
| 475 |
+
self.current_iter_ratio = None
|
| 476 |
+
self.mdm_loss_function = LOSS_MAPPING['ForMaskedLM']
|
| 477 |
+
self.causal_loss_function = LOSS_MAPPING['ForCausalLM']
|
| 478 |
+
|
| 479 |
+
self.post_init()
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def get_input_embeddings(self):
|
| 483 |
+
return self.encoder.embed_tokens
|
| 484 |
+
|
| 485 |
+
def set_input_embeddings(self, value):
|
| 486 |
+
self.encoder.embed_tokens = value
|
| 487 |
+
|
| 488 |
+
def get_output_embeddings(self):
|
| 489 |
+
return self.diffusion_head
|
| 490 |
+
|
| 491 |
+
def set_output_embeddings(self, new_embeddings):
|
| 492 |
+
self.diffusion_head = new_embeddings
|
| 493 |
+
|
| 494 |
+
def forward_process_complementary(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
|
| 495 |
+
device = input_ids.device
|
| 496 |
+
|
| 497 |
+
if self.config.dp_varying_mask_ratio:
|
| 498 |
+
import torch.distributed as dist
|
| 499 |
+
dp_rank = 0
|
| 500 |
+
if dist.is_initialized():
|
| 501 |
+
try:
|
| 502 |
+
dp_rank = dist.get_rank()
|
| 503 |
+
except Exception:
|
| 504 |
+
dp_rank = 0
|
| 505 |
+
generator = torch.Generator(device=device)
|
| 506 |
+
generator.manual_seed(torch.seed() + dp_rank)
|
| 507 |
+
else:
|
| 508 |
+
generator = None
|
| 509 |
+
|
| 510 |
+
noisy_input_ids = input_ids.clone()
|
| 511 |
+
input_ids_flat = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // block_size, block_size)
|
| 512 |
+
b, l = input_ids_flat.shape
|
| 513 |
+
t = torch.rand((b,), device=input_ids.device, generator=generator)
|
| 514 |
+
p_mask = (1 - eps) * t + eps
|
| 515 |
+
p_mask = p_mask[:, None].repeat(1, l)
|
| 516 |
+
|
| 517 |
+
masked_indices = (torch.rand((b, l), device=input_ids.device, generator=generator) < p_mask).reshape(noisy_input_ids.shape)
|
| 518 |
+
input_ids_flat = input_ids_flat.reshape(noisy_input_ids.shape)
|
| 519 |
+
|
| 520 |
+
complementary_noisy_input_ids = input_ids.clone()
|
| 521 |
+
complementary_masked_indices = ~masked_indices
|
| 522 |
+
|
| 523 |
+
if getattr(self.config, 'always_mask_im_end', False):
|
| 524 |
+
im_end_mask = (input_ids == self.config.im_end_token_id)
|
| 525 |
+
masked_indices = masked_indices | im_end_mask
|
| 526 |
+
complementary_masked_indices = complementary_masked_indices | im_end_mask
|
| 527 |
+
|
| 528 |
+
if loss_mask is not None:
|
| 529 |
+
masked_indices[loss_mask == 0] = 0
|
| 530 |
+
complementary_masked_indices[loss_mask == 0] = 0
|
| 531 |
+
|
| 532 |
+
noisy_input_ids[masked_indices] = self.mask_token_id
|
| 533 |
+
complementary_noisy_input_ids[complementary_masked_indices] = self.mask_token_id
|
| 534 |
+
|
| 535 |
+
noisy_input_ids = torch.cat([noisy_input_ids, complementary_noisy_input_ids], dim=0)
|
| 536 |
+
masked_indices = torch.cat([masked_indices, complementary_masked_indices], dim=0)
|
| 537 |
+
return noisy_input_ids, masked_indices, None
|
| 538 |
+
|
| 539 |
+
# ── Vision / multimodal helpers (ported from Mistral3Model) ──────────
|
| 540 |
+
|
| 541 |
+
IMAGE_TOKEN_ID = 19
|
| 542 |
+
|
| 543 |
+
def get_image_features(
|
| 544 |
+
self,
|
| 545 |
+
pixel_values: torch.FloatTensor,
|
| 546 |
+
image_sizes: torch.Tensor,
|
| 547 |
+
) -> torch.FloatTensor:
|
| 548 |
+
"""
|
| 549 |
+
Run the vision tower + multimodal projector and return a flat tensor
|
| 550 |
+
of image features ready to be scattered into the text embeddings.
|
| 551 |
+
|
| 552 |
+
Mirrors ``Mistral3Model.get_image_features`` from
|
| 553 |
+
transformers/models/mistral3/modeling_mistral3.py.
|
| 554 |
+
|
| 555 |
+
Returns:
|
| 556 |
+
Flat (total_image_tokens, hidden_size) tensor.
|
| 557 |
+
"""
|
| 558 |
+
vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
|
| 559 |
+
|
| 560 |
+
image_outputs = self.encoder.vision_tower(
|
| 561 |
+
pixel_values,
|
| 562 |
+
image_sizes=image_sizes,
|
| 563 |
+
output_hidden_states=True,
|
| 564 |
+
return_dict=True,
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
if isinstance(vision_feature_layer, int):
|
| 568 |
+
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
| 569 |
+
else:
|
| 570 |
+
hs_pool = [image_outputs.hidden_states[idx] for idx in vision_feature_layer]
|
| 571 |
+
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
| 572 |
+
|
| 573 |
+
image_features = self.encoder.multi_modal_projector(
|
| 574 |
+
selected_image_feature.squeeze(0), image_sizes,
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
# Split per image, then re-cat into one flat tensor
|
| 578 |
+
downsample_ratio = (
|
| 579 |
+
self.encoder.vision_tower.patch_size
|
| 580 |
+
* getattr(self.config, "spatial_merge_size", 2)
|
| 581 |
+
)
|
| 582 |
+
split_sizes = (
|
| 583 |
+
(torch.as_tensor(image_sizes, device=image_features.device) // downsample_ratio)
|
| 584 |
+
.prod(dim=-1)
|
| 585 |
+
.tolist()
|
| 586 |
+
)
|
| 587 |
+
# per_image = torch.split(image_features.squeeze(0), split_sizes)
|
| 588 |
+
per_image = torch.split(image_features, split_sizes)
|
| 589 |
+
|
| 590 |
+
return torch.cat(per_image, dim=0) # (total_tokens, hidden)
|
| 591 |
+
|
| 592 |
+
def _is_vision_frozen(self) -> bool:
|
| 593 |
+
"""True if vision_tower and multi_modal_projector have no parameters requiring grad (e.g. --freeze_vision_encoder)."""
|
| 594 |
+
vt = self.encoder.vision_tower
|
| 595 |
+
proj = self.encoder.multi_modal_projector
|
| 596 |
+
vt_has_grad = any(p.requires_grad for p in vt.parameters())
|
| 597 |
+
proj_has_grad = any(p.requires_grad for p in proj.parameters())
|
| 598 |
+
return not vt_has_grad and not proj_has_grad
|
| 599 |
+
|
| 600 |
+
def _embed_with_vision(
|
| 601 |
+
self,
|
| 602 |
+
input_ids: torch.LongTensor,
|
| 603 |
+
pixel_values: torch.FloatTensor,
|
| 604 |
+
image_sizes: torch.Tensor,
|
| 605 |
+
) -> torch.FloatTensor:
|
| 606 |
+
"""
|
| 607 |
+
Embed *input_ids* and scatter vision features into [IMG] pad positions.
|
| 608 |
+
|
| 609 |
+
Returns:
|
| 610 |
+
inputs_embeds (batch, seq_len, hidden_size)
|
| 611 |
+
"""
|
| 612 |
+
inputs_embeds = self.encoder.embed_tokens(input_ids)
|
| 613 |
+
image_features = self.get_image_features(pixel_values, image_sizes)
|
| 614 |
+
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 615 |
+
|
| 616 |
+
# Boolean mask: positions that are IMG pad tokens
|
| 617 |
+
special_image_mask = (input_ids == self.IMAGE_TOKEN_ID)
|
| 618 |
+
|
| 619 |
+
if self.training:
|
| 620 |
+
if self.config.complementary_mask:
|
| 621 |
+
image_features = image_features.repeat(2, 1)
|
| 622 |
+
if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 623 |
+
image_features = image_features.repeat(2, 1)
|
| 624 |
+
|
| 625 |
+
assert special_image_mask.sum() == image_features.shape[0], f"special_image_mask.sum() = {special_image_mask.sum()}, image_features.shape[0] = {image_features.shape[0]}"
|
| 626 |
+
# Expand to hidden dim for masked_scatter
|
| 627 |
+
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)
|
| 628 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
| 629 |
+
return inputs_embeds
|
| 630 |
+
|
| 631 |
+
def forward_process(self, input_ids, eps=1e-3, block_size=None, loss_mask=None):
|
| 632 |
+
b, l = input_ids.shape
|
| 633 |
+
device = input_ids.device
|
| 634 |
+
|
| 635 |
+
if self.config.dp_varying_mask_ratio:
|
| 636 |
+
# Enable different random seeds for each DP rank during sampling
|
| 637 |
+
import torch.distributed as dist
|
| 638 |
+
dp_rank = 0
|
| 639 |
+
if dist.is_initialized():
|
| 640 |
+
try:
|
| 641 |
+
dp_rank = dist.get_rank()
|
| 642 |
+
except Exception:
|
| 643 |
+
dp_rank = 0
|
| 644 |
+
# Use a local generator to avoid affecting global RNG state
|
| 645 |
+
generator = torch.Generator(device=device)
|
| 646 |
+
generator.manual_seed(torch.seed() + dp_rank)
|
| 647 |
+
else:
|
| 648 |
+
generator = None
|
| 649 |
+
|
| 650 |
+
if self.config.adaptive_mask_rate:
|
| 651 |
+
assert block_size is not None
|
| 652 |
+
|
| 653 |
+
# --- simple linear window mapping ---
|
| 654 |
+
bs_min = getattr(self.config, "t_bs_min", 16)
|
| 655 |
+
bs_max = getattr(self.config, "t_bs_max", 128)
|
| 656 |
+
w = getattr(self.config, "t_window_width", 0.6) # fixed width
|
| 657 |
+
|
| 658 |
+
# fraction in [0,1] (unclamped first)
|
| 659 |
+
frac = (float(block_size) - float(bs_min)) / max(1.0, float(bs_max - bs_min))
|
| 660 |
+
# upper bound decreases linearly from 1.0 -> 0.5
|
| 661 |
+
u_max = 1.0 - w * frac
|
| 662 |
+
# clamp to [0.6, 1.0] to handle bs outside [bs_min, bs_max]
|
| 663 |
+
u_max = max(0.6, min(1.0, u_max))
|
| 664 |
+
u_min = u_max - w # ensures width = w
|
| 665 |
+
|
| 666 |
+
# sample t ~ Uniform(u_min, u_max)
|
| 667 |
+
t = u_min + (u_max - u_min) * torch.rand(b, device=device, generator=generator)
|
| 668 |
+
else:
|
| 669 |
+
t = torch.rand(b, device=device, generator=generator)
|
| 670 |
+
|
| 671 |
+
p_mask = (1 - eps) * t + eps # shape: (b,)
|
| 672 |
+
p_mask = p_mask[:, None].expand(-1, l) # shape: (b, l)
|
| 673 |
+
|
| 674 |
+
masked_indices = torch.rand((b, l), device=device) < p_mask
|
| 675 |
+
|
| 676 |
+
if loss_mask is not None:
|
| 677 |
+
masked_indices[loss_mask == 0] = 0
|
| 678 |
+
|
| 679 |
+
noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 680 |
+
|
| 681 |
+
return noisy_batch, masked_indices, p_mask
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
def forward_process_exp(
|
| 685 |
+
self,
|
| 686 |
+
input_ids: torch.Tensor,
|
| 687 |
+
eps: float = 1e-3,
|
| 688 |
+
block_size: int | None = None,
|
| 689 |
+
half_life_ratio: float = 0.25, # λ = ln 2 / (half_life_ratio·L)
|
| 690 |
+
loss_mask: Optional[torch.Tensor] = None,
|
| 691 |
+
):
|
| 692 |
+
"""
|
| 693 |
+
Two-stage corruption with optional per-block sampling.
|
| 694 |
+
• Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
|
| 695 |
+
• Stage 2: sample exactly k positions with weights
|
| 696 |
+
w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
|
| 697 |
+
uniform when m→1).
|
| 698 |
+
If `block_size` is given, the procedure is run *independently*
|
| 699 |
+
inside each contiguous block of that length (last block may be shorter).
|
| 700 |
+
When block_size is provided, m is sampled per-block and p_mask is per-block.
|
| 701 |
+
Args
|
| 702 |
+
----
|
| 703 |
+
input_ids : (B, L) LongTensor
|
| 704 |
+
eps : minimum corruption ratio
|
| 705 |
+
block_size: if not None, operate block-wise with per-block m sampling
|
| 706 |
+
half_life_ratio : controls steepness when m→0
|
| 707 |
+
"""
|
| 708 |
+
B, L = input_ids.shape
|
| 709 |
+
device = input_ids.device
|
| 710 |
+
dtype = torch.float32
|
| 711 |
+
|
| 712 |
+
masked_indices = torch.zeros((B, L), dtype=torch.bool, device=device)
|
| 713 |
+
p_mask = torch.zeros((B, L), dtype=dtype, device=device)
|
| 714 |
+
|
| 715 |
+
# ---------- Stage 1 & 2: whole-sentence or block-wise -------------------
|
| 716 |
+
for b in range(B):
|
| 717 |
+
if block_size is None:
|
| 718 |
+
# ---------- Per-batch sampling (original behavior) ----------
|
| 719 |
+
m = eps + (1.0 - eps) * torch.rand(1, device=device).item() # scalar
|
| 720 |
+
k_tot = int(round(m * L))
|
| 721 |
+
k_tot = max(1, min(k_tot, L)) # clamp to [1, L]
|
| 722 |
+
|
| 723 |
+
# Fill p_mask for this batch
|
| 724 |
+
p_mask[b, :] = m
|
| 725 |
+
|
| 726 |
+
slope = 1.0 - m # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
|
| 727 |
+
|
| 728 |
+
# ------- single pool over the whole sentence -------------
|
| 729 |
+
lam_base = math.log(2.0) / (half_life_ratio * L) # base decay rate (λ when slope=1)
|
| 730 |
+
|
| 731 |
+
pos = torch.arange(L, device=device, dtype=dtype)
|
| 732 |
+
log_w = (lam_base * slope * pos).clone()
|
| 733 |
+
|
| 734 |
+
masked_indices[b] = gumbel_topk(log_w, k_tot)
|
| 735 |
+
|
| 736 |
+
else:
|
| 737 |
+
# ---------- Per-block sampling ----------
|
| 738 |
+
num_blocks = math.ceil(L / block_size)
|
| 739 |
+
lam_base = math.log(2.0) / (half_life_ratio * block_size) # base decay rate (λ when slope=1)
|
| 740 |
+
|
| 741 |
+
for blk in range(num_blocks):
|
| 742 |
+
start = blk * block_size
|
| 743 |
+
end = min((blk + 1) * block_size, L)
|
| 744 |
+
blk_len = end - start
|
| 745 |
+
|
| 746 |
+
# Sample m per block
|
| 747 |
+
m_blk = eps + (1.0 - eps) * torch.rand(1, device=device).item()
|
| 748 |
+
|
| 749 |
+
# Fill p_mask for this block
|
| 750 |
+
p_mask[b, start:end] = m_blk
|
| 751 |
+
|
| 752 |
+
# per-block budget
|
| 753 |
+
k_blk = int(round(m_blk * blk_len))
|
| 754 |
+
k_blk = max(0, min(k_blk, blk_len))
|
| 755 |
+
if k_blk == 0:
|
| 756 |
+
continue
|
| 757 |
+
|
| 758 |
+
slope = 1.0 - m_blk # ∈ [0,1]; 0 ⇒ uniform, 1 ⇒ late-heavy
|
| 759 |
+
|
| 760 |
+
pos = torch.arange(blk_len, device=device, dtype=dtype)
|
| 761 |
+
log_w = lam_base * slope * pos
|
| 762 |
+
|
| 763 |
+
blk_mask = gumbel_topk(log_w, k_blk)
|
| 764 |
+
masked_indices[b, start:end] = blk_mask
|
| 765 |
+
|
| 766 |
+
if loss_mask is not None:
|
| 767 |
+
masked_indices[loss_mask == 0] = 0
|
| 768 |
+
|
| 769 |
+
noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 770 |
+
return noisy_batch, masked_indices, p_mask
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
def forward(
|
| 774 |
+
self,
|
| 775 |
+
input_ids: torch.LongTensor,
|
| 776 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 777 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 778 |
+
labels: Optional[torch.LongTensor] = None,
|
| 779 |
+
split_len: Optional[int] = None,
|
| 780 |
+
past_key_values: Optional[Cache] = None,
|
| 781 |
+
block_size: Optional[int] = None,
|
| 782 |
+
block_diff_ppl: bool = False,
|
| 783 |
+
eps: float = 1e-3,
|
| 784 |
+
is_teacher: bool = False,
|
| 785 |
+
masked_indices: Optional[torch.Tensor] = None,
|
| 786 |
+
p_mask: Optional[torch.Tensor] = None,
|
| 787 |
+
teacher_logits: Optional[torch.Tensor] = None,
|
| 788 |
+
masked_indices_teacher: Optional[torch.Tensor] = None,
|
| 789 |
+
loss_mask: Optional[torch.Tensor] = None,
|
| 790 |
+
ce_loss_weight: float = 1.0,
|
| 791 |
+
output_last_hidden_states_only: bool = False,
|
| 792 |
+
skip_loss: bool = False,
|
| 793 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 794 |
+
image_sizes: Optional[torch.Tensor] = None,
|
| 795 |
+
**kwargs,
|
| 796 |
+
) -> CausalLMOutputWithPast:
|
| 797 |
+
|
| 798 |
+
batch_size, seq_len = input_ids.shape
|
| 799 |
+
|
| 800 |
+
if self.config.dlm_paradigm == 'bidirectional' or self.config.dlm_paradigm == 'autoregressive':
|
| 801 |
+
if labels is not None and torch.rand(1) < self.config.random_length_prob:
|
| 802 |
+
random_length = torch.randint(2, input_ids.shape[1] + 1, (1,))
|
| 803 |
+
input_ids = input_ids[:, :random_length]
|
| 804 |
+
labels = labels[:, :random_length]
|
| 805 |
+
|
| 806 |
+
if attention_mask is not None:
|
| 807 |
+
attention_mask = attention_mask[:, :random_length]
|
| 808 |
+
if position_ids is not None:
|
| 809 |
+
position_ids = position_ids[:, :random_length]
|
| 810 |
+
if loss_mask is not None:
|
| 811 |
+
loss_mask = loss_mask[:, :random_length]
|
| 812 |
+
|
| 813 |
+
elif self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 814 |
+
if labels is not None and block_size is None:
|
| 815 |
+
if torch.rand(1) < self.config.random_length_prob:
|
| 816 |
+
block_size = torch.randint(1, 8, (1,)).item() * 4 ## [4, 32] divisible by 4
|
| 817 |
+
else:
|
| 818 |
+
block_size = self.config.block_size
|
| 819 |
+
|
| 820 |
+
else:
|
| 821 |
+
raise ValueError(f"Unknown dLM paradigm: {self.config.dlm_paradigm}")
|
| 822 |
+
|
| 823 |
+
if labels is not None and self.config.dlm_paradigm != 'autoregressive':
|
| 824 |
+
if masked_indices is not None:
|
| 825 |
+
# assert p_mask is not None
|
| 826 |
+
|
| 827 |
+
if loss_mask is not None:
|
| 828 |
+
masked_indices[loss_mask == 0] = 0
|
| 829 |
+
|
| 830 |
+
noisy_inputs = torch.where(masked_indices, self.mask_token_id, input_ids)
|
| 831 |
+
|
| 832 |
+
else:
|
| 833 |
+
if self.config.complementary_mask:
|
| 834 |
+
loss_mask = (labels != -100)
|
| 835 |
+
noisy_inputs, masked_indices, p_mask = self.forward_process_complementary(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
|
| 836 |
+
else:
|
| 837 |
+
if self.config.tok_mask_half_life_ratio is not None:
|
| 838 |
+
noisy_inputs, masked_indices, p_mask = self.forward_process_exp(input_ids, eps=eps, block_size=block_size, half_life_ratio=self.config.tok_mask_half_life_ratio, loss_mask=loss_mask)
|
| 839 |
+
else:
|
| 840 |
+
noisy_inputs, masked_indices, p_mask = self.forward_process(input_ids, eps=eps, block_size=block_size, loss_mask=loss_mask)
|
| 841 |
+
|
| 842 |
+
else:
|
| 843 |
+
noisy_inputs = input_ids
|
| 844 |
+
masked_indices = None
|
| 845 |
+
p_mask = None
|
| 846 |
+
|
| 847 |
+
if self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 848 |
+
for layer in self.encoder.layers:
|
| 849 |
+
if hasattr(layer.self_attn, 'set_attention_mode'):
|
| 850 |
+
layer.self_attn.set_attention_mode(self.config.dlm_paradigm, block_size=block_size)
|
| 851 |
+
|
| 852 |
+
input_ids_len = noisy_inputs.shape[1]
|
| 853 |
+
if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 854 |
+
if position_ids is None:
|
| 855 |
+
position_ids = torch.arange(input_ids_len, device=noisy_inputs.device).unsqueeze(0)
|
| 856 |
+
if self.config.complementary_mask:
|
| 857 |
+
noisy_inputs = torch.cat([noisy_inputs, torch.cat([input_ids, input_ids], dim=0)], dim=1)
|
| 858 |
+
else:
|
| 859 |
+
noisy_inputs = torch.cat([noisy_inputs, input_ids], dim=1)
|
| 860 |
+
|
| 861 |
+
if block_diff_ppl:
|
| 862 |
+
if position_ids is None:
|
| 863 |
+
position_ids = torch.arange(input_ids_len // 2, device=noisy_inputs.device).unsqueeze(0)
|
| 864 |
+
|
| 865 |
+
# ── Vision: replace IMG pad embeddings with image features ────────
|
| 866 |
+
if pixel_values is not None and image_sizes is not None:
|
| 867 |
+
inputs_embeds = self._embed_with_vision(noisy_inputs, pixel_values, image_sizes)
|
| 868 |
+
enc_out = self.encoder(
|
| 869 |
+
past_key_values=past_key_values,
|
| 870 |
+
inputs_embeds=inputs_embeds,
|
| 871 |
+
attention_mask=attention_mask,
|
| 872 |
+
position_ids=position_ids,
|
| 873 |
+
is_training=(labels is not None) or (block_diff_ppl),
|
| 874 |
+
**kwargs,
|
| 875 |
+
)
|
| 876 |
+
elif self.training and pixel_values is None and not self._is_vision_frozen():
|
| 877 |
+
vt = self.encoder.vision_tower
|
| 878 |
+
_p = vt.patch_size
|
| 879 |
+
_merge = getattr(self.config, "spatial_merge_size", 2)
|
| 880 |
+
_side = _p * _merge
|
| 881 |
+
_c = getattr(vt.config, "num_channels", 3)
|
| 882 |
+
_dtype = next(vt.parameters()).dtype
|
| 883 |
+
dummy_pixel = torch.zeros(
|
| 884 |
+
1, _c, _side, _side,
|
| 885 |
+
dtype=_dtype, device=noisy_inputs.device,
|
| 886 |
+
)
|
| 887 |
+
dummy_image_sizes = torch.tensor(
|
| 888 |
+
[(int(_side), int(_side))],
|
| 889 |
+
dtype=torch.long, device=noisy_inputs.device,
|
| 890 |
+
)
|
| 891 |
+
dummy_features = self.get_image_features(dummy_pixel, dummy_image_sizes)
|
| 892 |
+
inputs_embeds = self.encoder.embed_tokens(noisy_inputs)
|
| 893 |
+
inputs_embeds = inputs_embeds + dummy_features.sum() * 0
|
| 894 |
+
enc_out = self.encoder(
|
| 895 |
+
past_key_values=past_key_values,
|
| 896 |
+
inputs_embeds=inputs_embeds,
|
| 897 |
+
attention_mask=attention_mask,
|
| 898 |
+
position_ids=position_ids,
|
| 899 |
+
is_training=(labels is not None) or (block_diff_ppl),
|
| 900 |
+
**kwargs,
|
| 901 |
+
)
|
| 902 |
+
else:
|
| 903 |
+
enc_out = self.encoder(
|
| 904 |
+
past_key_values=past_key_values,
|
| 905 |
+
input_ids=noisy_inputs,
|
| 906 |
+
attention_mask=attention_mask,
|
| 907 |
+
position_ids=position_ids,
|
| 908 |
+
is_training=(labels is not None) or (block_diff_ppl),
|
| 909 |
+
**kwargs,
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
if output_last_hidden_states_only:
|
| 913 |
+
return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
|
| 914 |
+
|
| 915 |
+
logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
|
| 916 |
+
causal_logits = None
|
| 917 |
+
|
| 918 |
+
if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
|
| 919 |
+
if self.config.dlm_paradigm == 'sbd_block_diff':
|
| 920 |
+
causal_logits = logits[:, input_ids_len:]
|
| 921 |
+
else:
|
| 922 |
+
causal_logits = None
|
| 923 |
+
|
| 924 |
+
logits = logits[:, :input_ids_len]
|
| 925 |
+
|
| 926 |
+
loss = None
|
| 927 |
+
if getattr(self.config, 'complementary_mask', False) and self.config.dlm_paradigm == 'sbd_block_diff':
|
| 928 |
+
_raw_nib = kwargs.get('num_items_in_batch', None)
|
| 929 |
+
kwargs = {**kwargs, 'num_items_in_batch': 2 * kwargs.get('num_items_in_batch', 1)}
|
| 930 |
+
if self.training and (not hasattr(self, '_nib_logged') or not self._nib_logged):
|
| 931 |
+
import torch.distributed as dist
|
| 932 |
+
_rank = dist.get_rank() if dist.is_initialized() else 0
|
| 933 |
+
if _rank == 0:
|
| 934 |
+
print(f"[DEBUG-NIB] raw num_items_in_batch from Trainer: {_raw_nib}, "
|
| 935 |
+
f"after 2x: {kwargs['num_items_in_batch']}, "
|
| 936 |
+
f"labels non-(-100): {(labels != -100).sum().item() if labels is not None else 'N/A'}, "
|
| 937 |
+
f"batch_size={input_ids.shape[0]}, seq_len={input_ids.shape[1]}", flush=True)
|
| 938 |
+
self._nib_logged = True
|
| 939 |
+
if labels is not None and not skip_loss:
|
| 940 |
+
if self.config.dlm_paradigm == 'autoregressive':
|
| 941 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 942 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 943 |
+
|
| 944 |
+
if loss_mask is None:
|
| 945 |
+
loss_fct = CrossEntropyLoss()
|
| 946 |
+
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
| 947 |
+
shift_labels = shift_labels.view(-1)
|
| 948 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 949 |
+
|
| 950 |
+
else:
|
| 951 |
+
loss_mask = loss_mask[..., 1:].contiguous()
|
| 952 |
+
|
| 953 |
+
loss_fct = CrossEntropyLoss(reduction='none')
|
| 954 |
+
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
| 955 |
+
shift_labels = shift_labels.view(-1)
|
| 956 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 957 |
+
|
| 958 |
+
token_losses = loss_fct(shift_logits, shift_labels)
|
| 959 |
+
|
| 960 |
+
flat_loss_mask = loss_mask.reshape(-1)
|
| 961 |
+
loss = token_losses[flat_loss_mask == 1].sum() / flat_loss_mask.sum()
|
| 962 |
+
|
| 963 |
+
else:
|
| 964 |
+
# Handle DREAM vs LLADA style losses
|
| 965 |
+
if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
|
| 966 |
+
logits = logits[..., :-1, :].contiguous()
|
| 967 |
+
labels = labels[..., 1:].contiguous()
|
| 968 |
+
masked_indices = masked_indices[:, 1:]
|
| 969 |
+
if p_mask is not None:
|
| 970 |
+
p_mask = p_mask[:, 1:]
|
| 971 |
+
|
| 972 |
+
if self.config.ada_perm_ratio_per_block is not None:
|
| 973 |
+
# Only compute loss for the top ada_perm_ratio_per_block tokens by confidence within each block
|
| 974 |
+
block_size = self.config.block_size
|
| 975 |
+
batch_size, seq_len = masked_indices.shape
|
| 976 |
+
num_blocks = seq_len // block_size
|
| 977 |
+
|
| 978 |
+
# Get the max logit (confidence) for each position
|
| 979 |
+
confidence = logits.max(dim=-1).values.detach() # (batch_size, seq_len)
|
| 980 |
+
|
| 981 |
+
# Create a mask for tokens to include in loss
|
| 982 |
+
selected_mask = torch.zeros_like(masked_indices, dtype=torch.bool)
|
| 983 |
+
|
| 984 |
+
for blk in range(num_blocks):
|
| 985 |
+
start = blk * block_size
|
| 986 |
+
end = min((blk + 1) * block_size, seq_len)
|
| 987 |
+
|
| 988 |
+
# Get masked indices within this block
|
| 989 |
+
block_masked = masked_indices[:, start:end] # (batch_size, block_len)
|
| 990 |
+
block_confidence = confidence[:, start:end] # (batch_size, block_len)
|
| 991 |
+
|
| 992 |
+
for b in range(batch_size):
|
| 993 |
+
# Get positions that are masked in this block for this batch
|
| 994 |
+
masked_positions = torch.where(block_masked[b])[0]
|
| 995 |
+
num_masked = len(masked_positions)
|
| 996 |
+
|
| 997 |
+
if num_masked > 0:
|
| 998 |
+
# Number of tokens to keep (top by confidence)
|
| 999 |
+
k = min(max(1, int(block_size * self.config.ada_perm_ratio_per_block)), num_masked)
|
| 1000 |
+
|
| 1001 |
+
# Get confidence values for masked positions
|
| 1002 |
+
masked_confidence = block_confidence[b, masked_positions]
|
| 1003 |
+
|
| 1004 |
+
# Get indices of top-k confident tokens
|
| 1005 |
+
_, topk_indices = torch.topk(masked_confidence, k)
|
| 1006 |
+
selected_positions = masked_positions[topk_indices]
|
| 1007 |
+
|
| 1008 |
+
# Mark these positions in the selected mask
|
| 1009 |
+
selected_mask[b, start + selected_positions] = True
|
| 1010 |
+
|
| 1011 |
+
# Calculate loss only for selected positions
|
| 1012 |
+
token_loss = torch.nn.functional.cross_entropy(
|
| 1013 |
+
logits[selected_mask],
|
| 1014 |
+
labels[selected_mask],
|
| 1015 |
+
reduction='none'
|
| 1016 |
+
) / p_mask[selected_mask]
|
| 1017 |
+
|
| 1018 |
+
num_mask_tokens = selected_mask.sum()
|
| 1019 |
+
|
| 1020 |
+
elif getattr(self.config, 'complementary_mask', False):
|
| 1021 |
+
token_loss = self.mdm_loss_function(
|
| 1022 |
+
logits=logits[masked_indices],
|
| 1023 |
+
labels=torch.cat([labels, labels], dim=0)[masked_indices],
|
| 1024 |
+
vocab_size=self.config.vocab_size,
|
| 1025 |
+
**kwargs
|
| 1026 |
+
)
|
| 1027 |
+
num_mask_tokens = masked_indices.sum()
|
| 1028 |
+
|
| 1029 |
+
else:
|
| 1030 |
+
# Calculate token-wise cross entropy loss for masked positions in B
|
| 1031 |
+
token_loss = torch.nn.functional.cross_entropy(
|
| 1032 |
+
logits[masked_indices],
|
| 1033 |
+
labels[masked_indices],
|
| 1034 |
+
reduction='none'
|
| 1035 |
+
) / p_mask[masked_indices]
|
| 1036 |
+
|
| 1037 |
+
num_mask_tokens = masked_indices.sum()
|
| 1038 |
+
|
| 1039 |
+
if self.config.global_loss_avg:
|
| 1040 |
+
loss = token_loss.sum()
|
| 1041 |
+
else:
|
| 1042 |
+
loss = token_loss.sum() / num_mask_tokens
|
| 1043 |
+
|
| 1044 |
+
if self.config.ada_dlm_loss_ratio is not None:
|
| 1045 |
+
assert self.current_iter_ratio is not None
|
| 1046 |
+
assert self.config.dlm_loss_weight is not None
|
| 1047 |
+
|
| 1048 |
+
dlm_loss_weight = min(self.config.dlm_loss_weight, self.current_iter_ratio / self.config.ada_dlm_loss_ratio * self.config.dlm_loss_weight)
|
| 1049 |
+
loss = dlm_loss_weight * loss
|
| 1050 |
+
elif self.config.dlm_loss_weight is not None:
|
| 1051 |
+
loss = self.config.dlm_loss_weight * loss
|
| 1052 |
+
|
| 1053 |
+
if self.config.dlm_paradigm == 'sbd_block_diff':
|
| 1054 |
+
|
| 1055 |
+
if getattr(self.config, 'complementary_mask', False):
|
| 1056 |
+
ar_loss = self.causal_loss_function(
|
| 1057 |
+
logits=causal_logits[:logits.shape[0] // 2, :],
|
| 1058 |
+
labels=labels,
|
| 1059 |
+
vocab_size=self.config.vocab_size,
|
| 1060 |
+
**kwargs
|
| 1061 |
+
)
|
| 1062 |
+
|
| 1063 |
+
_diff_val = loss.detach().item()
|
| 1064 |
+
_ar_val = ar_loss.detach().item()
|
| 1065 |
+
if not hasattr(self, '_loss_accum_count'):
|
| 1066 |
+
self._loss_diff_accum = 0.0
|
| 1067 |
+
self._loss_ar_accum = 0.0
|
| 1068 |
+
self._loss_accum_count = 0
|
| 1069 |
+
self._loss_diff_accum += _diff_val
|
| 1070 |
+
self._loss_ar_accum += _ar_val
|
| 1071 |
+
self._loss_accum_count += 1
|
| 1072 |
+
self.loss_diffusion = self._loss_diff_accum
|
| 1073 |
+
self.loss_ar = self._loss_ar_accum
|
| 1074 |
+
|
| 1075 |
+
loss = loss + ar_loss
|
| 1076 |
+
else:
|
| 1077 |
+
causal_logits = causal_logits[..., :-1, :].contiguous()
|
| 1078 |
+
causal_logits = causal_logits.view(-1, causal_logits.size(-1))
|
| 1079 |
+
|
| 1080 |
+
if hasattr(self.config, 'dlm_type') and self.config.dlm_type == 'dream':
|
| 1081 |
+
causal_labels = labels.view(-1)
|
| 1082 |
+
else:
|
| 1083 |
+
causal_labels = labels[..., 1:].contiguous().view(-1)
|
| 1084 |
+
|
| 1085 |
+
if self.config.global_loss_avg:
|
| 1086 |
+
loss_fct = CrossEntropyLoss(reduction='sum')
|
| 1087 |
+
ar_loss = loss_fct(causal_logits, causal_labels)
|
| 1088 |
+
|
| 1089 |
+
self.loss_diffusion = loss.detach().item() / num_mask_tokens
|
| 1090 |
+
self.loss_ar = ar_loss.detach().item() / seq_len
|
| 1091 |
+
|
| 1092 |
+
loss = loss + self.config.ar_loss_weight * ar_loss
|
| 1093 |
+
|
| 1094 |
+
else:
|
| 1095 |
+
loss_fct = CrossEntropyLoss()
|
| 1096 |
+
ar_loss = loss_fct(causal_logits, causal_labels)
|
| 1097 |
+
|
| 1098 |
+
self.loss_diffusion = loss.detach().item()
|
| 1099 |
+
self.loss_ar = ar_loss.detach().item()
|
| 1100 |
+
|
| 1101 |
+
loss = loss + self.config.ar_loss_weight * ar_loss
|
| 1102 |
+
|
| 1103 |
+
# if self.config.global_loss_avg:
|
| 1104 |
+
# if self.config.dlm_paradigm == 'sbd_block_diff':
|
| 1105 |
+
# loss = (loss, num_mask_tokens + int(self.config.ar_loss_weight * seq_len))
|
| 1106 |
+
# else:
|
| 1107 |
+
# loss = (loss, num_mask_tokens)
|
| 1108 |
+
|
| 1109 |
+
return NemotronLabsDiffusionVLMOutputWithPast(
|
| 1110 |
+
loss=loss if not is_teacher else logits,
|
| 1111 |
+
logits=logits,
|
| 1112 |
+
causal_logits=causal_logits,
|
| 1113 |
+
past_key_values=enc_out.past_key_values,
|
| 1114 |
+
hidden_states=None,
|
| 1115 |
+
attentions=None,
|
| 1116 |
+
)
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold,
|
| 1120 |
+
causal_context=True, temperature=0, pixel_values=None, image_sizes=None, eos_token_id=None):
|
| 1121 |
+
out_ids, nfe = generate_with_prefix_cache_block_diff(
|
| 1122 |
+
model=self,
|
| 1123 |
+
prompt=prompt_ids,
|
| 1124 |
+
gen_length=max_new_tokens,
|
| 1125 |
+
steps=steps,
|
| 1126 |
+
block_length=block_length,
|
| 1127 |
+
remasking="low_confidence",
|
| 1128 |
+
temperature=temperature,
|
| 1129 |
+
mask_id=self.mask_token_id,
|
| 1130 |
+
threshold=threshold,
|
| 1131 |
+
shift_logits=shift_logits,
|
| 1132 |
+
neg_entropy=False,
|
| 1133 |
+
causal_context=causal_context,
|
| 1134 |
+
pixel_values=pixel_values,
|
| 1135 |
+
image_sizes=image_sizes,
|
| 1136 |
+
eos_token_id=eos_token_id,
|
| 1137 |
+
)
|
| 1138 |
+
|
| 1139 |
+
return out_ids, nfe
|
| 1140 |
+
|
| 1141 |
+
@torch.no_grad()
|
| 1142 |
+
def sbd_inference_diffusion_quadratic(
|
| 1143 |
+
self,
|
| 1144 |
+
clean_input_ids: Optional[torch.Tensor],
|
| 1145 |
+
draft_input_ids: torch.Tensor,
|
| 1146 |
+
block_length: int,
|
| 1147 |
+
draft_only: bool = False,
|
| 1148 |
+
past_key_values: Optional[Cache] = None,
|
| 1149 |
+
use_cache: bool = False,
|
| 1150 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1151 |
+
image_sizes: Optional[torch.Tensor] = None,
|
| 1152 |
+
):
|
| 1153 |
+
enc_config = self.encoder.config
|
| 1154 |
+
enc_config.use_sbd_objective = True
|
| 1155 |
+
enc_config.block_length = block_length
|
| 1156 |
+
|
| 1157 |
+
if draft_only:
|
| 1158 |
+
assert clean_input_ids is not None
|
| 1159 |
+
|
| 1160 |
+
if use_cache and past_key_values is None:
|
| 1161 |
+
past_key_values = DynamicCache()
|
| 1162 |
+
|
| 1163 |
+
enc_config.self_spec_inference_mode = "default"
|
| 1164 |
+
input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
|
| 1165 |
+
if pixel_values is not None and image_sizes is not None:
|
| 1166 |
+
inputs_embeds = self._embed_with_vision(input_ids, pixel_values, image_sizes)
|
| 1167 |
+
outputs = self.encoder(
|
| 1168 |
+
inputs_embeds=inputs_embeds,
|
| 1169 |
+
position_ids=None,
|
| 1170 |
+
past_key_values=past_key_values,
|
| 1171 |
+
use_cache=use_cache,
|
| 1172 |
+
is_training=False,
|
| 1173 |
+
)
|
| 1174 |
+
else:
|
| 1175 |
+
outputs = self.encoder(
|
| 1176 |
+
input_ids=input_ids,
|
| 1177 |
+
position_ids=None,
|
| 1178 |
+
past_key_values=past_key_values,
|
| 1179 |
+
use_cache=use_cache,
|
| 1180 |
+
is_training=False,
|
| 1181 |
+
)
|
| 1182 |
+
|
| 1183 |
+
hidden_states = outputs.last_hidden_state
|
| 1184 |
+
logits = self.diffusion_head(hidden_states)
|
| 1185 |
+
|
| 1186 |
+
past_key_values = getattr(outputs, "past_key_values", None)
|
| 1187 |
+
if use_cache and past_key_values is not None:
|
| 1188 |
+
_crop_dynamic_cache(past_key_values, clean_input_ids.shape[1])
|
| 1189 |
+
|
| 1190 |
+
return logits, past_key_values
|
| 1191 |
+
else:
|
| 1192 |
+
enc_config.self_spec_inference_mode = "quadratic"
|
| 1193 |
+
|
| 1194 |
+
draft_len = block_length * (block_length + 1)
|
| 1195 |
+
draft_input_ids = torch.cat(
|
| 1196 |
+
[
|
| 1197 |
+
draft_input_ids.view(-1, block_length, 1),
|
| 1198 |
+
torch.full(
|
| 1199 |
+
(draft_input_ids.shape[0], block_length, block_length),
|
| 1200 |
+
fill_value=self.config.mask_token_id,
|
| 1201 |
+
device=draft_input_ids.device,
|
| 1202 |
+
),
|
| 1203 |
+
],
|
| 1204 |
+
dim=-1,
|
| 1205 |
+
).view(-1, draft_len)
|
| 1206 |
+
|
| 1207 |
+
if use_cache:
|
| 1208 |
+
assert past_key_values is not None, (
|
| 1209 |
+
"Past key values should be provided when using cache, e.g. run draft_only=True first."
|
| 1210 |
+
)
|
| 1211 |
+
assert clean_input_ids is None, (
|
| 1212 |
+
"Clean input ids should already be in cache, thus none should be provided."
|
| 1213 |
+
)
|
| 1214 |
+
clean_len = past_key_values.get_seq_length()
|
| 1215 |
+
input_ids = draft_input_ids
|
| 1216 |
+
else:
|
| 1217 |
+
clean_len = clean_input_ids.shape[1]
|
| 1218 |
+
input_ids = torch.cat([clean_input_ids, draft_input_ids], dim=-1)
|
| 1219 |
+
|
| 1220 |
+
per_block_position_ids = torch.arange(
|
| 1221 |
+
clean_len, clean_len + block_length + 1, device=draft_input_ids.device
|
| 1222 |
+
)[None,].repeat(block_length, 1)
|
| 1223 |
+
per_block_position_ids += torch.arange(block_length, device=draft_input_ids.device).view(-1, 1)
|
| 1224 |
+
|
| 1225 |
+
if use_cache:
|
| 1226 |
+
position_ids = per_block_position_ids.view(-1)[None,]
|
| 1227 |
+
else:
|
| 1228 |
+
clean_position_ids = torch.arange(clean_len, device=draft_input_ids.device)
|
| 1229 |
+
position_ids = torch.cat([clean_position_ids, per_block_position_ids.view(-1)], dim=-1)[None,]
|
| 1230 |
+
|
| 1231 |
+
if pixel_values is not None and image_sizes is not None and not use_cache:
|
| 1232 |
+
inputs_embeds = self._embed_with_vision(input_ids, pixel_values, image_sizes)
|
| 1233 |
+
outputs = self.encoder(
|
| 1234 |
+
inputs_embeds=inputs_embeds,
|
| 1235 |
+
position_ids=position_ids,
|
| 1236 |
+
past_key_values=past_key_values,
|
| 1237 |
+
use_cache=use_cache,
|
| 1238 |
+
is_training=False,
|
| 1239 |
+
)
|
| 1240 |
+
else:
|
| 1241 |
+
outputs = self.encoder(
|
| 1242 |
+
input_ids=input_ids,
|
| 1243 |
+
position_ids=position_ids,
|
| 1244 |
+
past_key_values=past_key_values,
|
| 1245 |
+
use_cache=use_cache,
|
| 1246 |
+
is_training=False,
|
| 1247 |
+
)
|
| 1248 |
+
|
| 1249 |
+
hidden_states = outputs.last_hidden_state
|
| 1250 |
+
logits = self.diffusion_head(hidden_states)
|
| 1251 |
+
past_key_values = getattr(outputs, "past_key_values", None)
|
| 1252 |
+
|
| 1253 |
+
if use_cache and past_key_values is not None:
|
| 1254 |
+
_extract_draft_kv_cache(past_key_values, clean_len, block_length)
|
| 1255 |
+
|
| 1256 |
+
return logits, past_key_values
|
| 1257 |
+
|
| 1258 |
+
@torch.no_grad()
|
| 1259 |
+
def self_spec_generate(
|
| 1260 |
+
self,
|
| 1261 |
+
prompt_ids: torch.Tensor,
|
| 1262 |
+
max_new_tokens: int = 128,
|
| 1263 |
+
steps: int = 128,
|
| 1264 |
+
block_length: int = 16,
|
| 1265 |
+
ar_mix_weight: Optional[float] = None,
|
| 1266 |
+
temperature: float = 0.0,
|
| 1267 |
+
mask_token_id: Optional[int] = None,
|
| 1268 |
+
eos_token_id: Optional[int] = None,
|
| 1269 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1270 |
+
image_sizes: Optional[torch.Tensor] = None,
|
| 1271 |
+
):
|
| 1272 |
+
self.config.use_sbd_objective = True
|
| 1273 |
+
self.config.dlm_paradigm = "sbd"
|
| 1274 |
+
|
| 1275 |
+
if prompt_ids.shape[0] != 1:
|
| 1276 |
+
raise ValueError("Self speculation quadratic decoding currently requires batch_size == 1")
|
| 1277 |
+
|
| 1278 |
+
token_mask_id = mask_token_id if mask_token_id is not None else self.config.mask_token_id
|
| 1279 |
+
if eos_token_id is None:
|
| 1280 |
+
eos_token_id = getattr(self.config, "eos_token_id", None)
|
| 1281 |
+
|
| 1282 |
+
x = torch.full(
|
| 1283 |
+
(1, prompt_ids.shape[1] + max_new_tokens + block_length * 2),
|
| 1284 |
+
token_mask_id,
|
| 1285 |
+
dtype=torch.long,
|
| 1286 |
+
device=prompt_ids.device,
|
| 1287 |
+
)
|
| 1288 |
+
x[:, : prompt_ids.shape[1]] = prompt_ids.clone()
|
| 1289 |
+
|
| 1290 |
+
if max_new_tokens % block_length != 0:
|
| 1291 |
+
raise ValueError("max_new_tokens must be divisible by block_length")
|
| 1292 |
+
num_blocks = max_new_tokens // block_length
|
| 1293 |
+
if steps % num_blocks != 0:
|
| 1294 |
+
raise ValueError("steps must be divisible by (max_new_tokens // block_length)")
|
| 1295 |
+
|
| 1296 |
+
prompt_len = prompt_ids.shape[1]
|
| 1297 |
+
nfe = 0
|
| 1298 |
+
nfe += 1
|
| 1299 |
+
logits, past_key_values = self.sbd_inference_diffusion_quadratic(
|
| 1300 |
+
clean_input_ids=x[:, :prompt_len],
|
| 1301 |
+
draft_input_ids=x[:, prompt_len : prompt_len + block_length],
|
| 1302 |
+
block_length=block_length,
|
| 1303 |
+
draft_only=True,
|
| 1304 |
+
use_cache=True,
|
| 1305 |
+
pixel_values=pixel_values,
|
| 1306 |
+
image_sizes=image_sizes,
|
| 1307 |
+
)
|
| 1308 |
+
|
| 1309 |
+
logits_proposal = logits[:, prompt_len - 1 : prompt_len + block_length]
|
| 1310 |
+
logits_proposal[:, 1] = logits_proposal[:, 0]
|
| 1311 |
+
logits_proposal = logits_proposal[:, 1:]
|
| 1312 |
+
x0_proposal = torch.argmax(logits_proposal, dim=-1)
|
| 1313 |
+
x[:, prompt_len : prompt_len + block_length] = x0_proposal
|
| 1314 |
+
|
| 1315 |
+
total_accept_token = 0
|
| 1316 |
+
while True:
|
| 1317 |
+
nfe += 1
|
| 1318 |
+
block_start = prompt_len + total_accept_token
|
| 1319 |
+
block_end = block_start + block_length
|
| 1320 |
+
draft_input_ids = x[:, block_start:block_end]
|
| 1321 |
+
|
| 1322 |
+
logits, past_key_values = self.sbd_inference_diffusion_quadratic(
|
| 1323 |
+
clean_input_ids=None,
|
| 1324 |
+
draft_input_ids=draft_input_ids,
|
| 1325 |
+
block_length=block_length,
|
| 1326 |
+
draft_only=False,
|
| 1327 |
+
past_key_values=past_key_values,
|
| 1328 |
+
use_cache=True,
|
| 1329 |
+
pixel_values=pixel_values,
|
| 1330 |
+
image_sizes=image_sizes,
|
| 1331 |
+
)
|
| 1332 |
+
|
| 1333 |
+
useful_token_logits = logits.view(1, block_length, block_length + 1, -1)
|
| 1334 |
+
if ar_mix_weight is None:
|
| 1335 |
+
useful_token_logits[:, :, 1] = useful_token_logits[:, :, 0]
|
| 1336 |
+
else:
|
| 1337 |
+
if not (0.0 <= ar_mix_weight <= 1.0):
|
| 1338 |
+
raise ValueError("ar_mix_weight must be between 0 and 1")
|
| 1339 |
+
mix_logits = useful_token_logits[:, :, 0] * ar_mix_weight + useful_token_logits[:, :, 1] * (1 - ar_mix_weight)
|
| 1340 |
+
useful_token_logits[:, :, 0] = mix_logits
|
| 1341 |
+
useful_token_logits[:, :, 1] = mix_logits
|
| 1342 |
+
|
| 1343 |
+
if temperature > 0:
|
| 1344 |
+
useful_token_logits = useful_token_logits / temperature
|
| 1345 |
+
|
| 1346 |
+
useful_token_pred = torch.argmax(useful_token_logits, dim=-1)
|
| 1347 |
+
new_draft_input_ids = useful_token_pred[:, 0, 1:]
|
| 1348 |
+
accept_cnt = 1
|
| 1349 |
+
|
| 1350 |
+
while accept_cnt < block_length:
|
| 1351 |
+
if useful_token_pred[:, accept_cnt - 1, 0].item() != draft_input_ids[:, accept_cnt].item():
|
| 1352 |
+
break
|
| 1353 |
+
new_draft_input_ids = useful_token_pred[:, accept_cnt, 1:]
|
| 1354 |
+
accept_cnt += 1
|
| 1355 |
+
|
| 1356 |
+
x[:, block_start : block_start + accept_cnt] = draft_input_ids[:, :accept_cnt]
|
| 1357 |
+
|
| 1358 |
+
# EoS early stopping
|
| 1359 |
+
if eos_token_id is not None:
|
| 1360 |
+
accepted = x[0, block_start : block_start + accept_cnt]
|
| 1361 |
+
eos_positions = (accepted == eos_token_id).nonzero(as_tuple=True)[0]
|
| 1362 |
+
if len(eos_positions) > 0:
|
| 1363 |
+
first_eos_rel = eos_positions[0].item()
|
| 1364 |
+
total_accept_token += first_eos_rel + 1
|
| 1365 |
+
output_end = prompt_len + total_accept_token
|
| 1366 |
+
return x[:, :output_end], nfe
|
| 1367 |
+
|
| 1368 |
+
x[:, block_start + accept_cnt : block_start + accept_cnt + block_length] = new_draft_input_ids
|
| 1369 |
+
_crop_dynamic_cache(past_key_values, block_start + accept_cnt)
|
| 1370 |
+
total_accept_token += accept_cnt
|
| 1371 |
+
|
| 1372 |
+
if total_accept_token >= max_new_tokens:
|
| 1373 |
+
break
|
| 1374 |
+
|
| 1375 |
+
return x[:, : -(block_length * 2)], nfe
|
| 1376 |
+
|
| 1377 |
+
|
| 1378 |
+
__all__ = ["NemotronLabsDiffusionVLMModel", "NemotronLabsDiffusionVLMFlexAttention"]
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"|<MASK>|"
|
| 4 |
+
],
|
| 5 |
+
"bos_token": {
|
| 6 |
+
"content": "<s>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"eos_token": {
|
| 13 |
+
"content": "<|im_end|>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false
|
| 18 |
+
},
|
| 19 |
+
"pad_token": {
|
| 20 |
+
"content": "<|im_end|>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false
|
| 25 |
+
},
|
| 26 |
+
"unk_token": {
|
| 27 |
+
"content": "<unk>",
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"normalized": false,
|
| 30 |
+
"rstrip": false,
|
| 31 |
+
"single_word": false
|
| 32 |
+
}
|
| 33 |
+
}
|
tokenization_nemotron_labs_diffusion_vlm.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom tokenizer for Nemotron-Diffusion-Exp-Ministral-8B-Instruct (final-template).
|
| 3 |
+
|
| 4 |
+
Extends PreTrainedTokenizerFast with a `process_messages` method that
|
| 5 |
+
handles image token expansion and pixel value preprocessing, analogous
|
| 6 |
+
to MistralCommonBackend.apply_chat_template(return_dict=True).
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
|
| 10 |
+
result = tokenizer.process_messages(messages)
|
| 11 |
+
# result["input_ids"] – (1, seq_len) with expanded image tokens
|
| 12 |
+
# result["pixel_values"] – (N, 3, H, W) if images present
|
| 13 |
+
# result["image_sizes"] – list of (H, W) tuples
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from typing import Any, Dict, List
|
| 17 |
+
|
| 18 |
+
from transformers import PreTrainedTokenizerFast
|
| 19 |
+
|
| 20 |
+
from .image_processing import process_messages as _process_messages
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class NemotronLabsDiffusionVLMTokenizerFast(PreTrainedTokenizerFast):
|
| 24 |
+
"""PreTrainedTokenizerFast + image-aware process_messages()."""
|
| 25 |
+
|
| 26 |
+
def process_messages(
|
| 27 |
+
self,
|
| 28 |
+
messages: List[Dict[str, Any]],
|
| 29 |
+
**kwargs,
|
| 30 |
+
) -> Dict[str, Any]:
|
| 31 |
+
"""
|
| 32 |
+
Process chat messages with optional images.
|
| 33 |
+
|
| 34 |
+
Renders the chat template, expands image placeholders based on
|
| 35 |
+
actual image dimensions, preprocesses pixel values, and tokenizes.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
messages: OpenAI-style list of message dicts.
|
| 39 |
+
**kwargs: forwarded to image_processing.process_messages
|
| 40 |
+
(patch_size, spatial_merge_size, max_image_size,
|
| 41 |
+
return_tensors, enable_thinking).
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
dict with input_ids, and optionally pixel_values + image_sizes.
|
| 45 |
+
"""
|
| 46 |
+
return _process_messages(self, messages, **kwargs)
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2e04613060199ab156b35bf2334e381748ae41311e8785efb330bc66e16670d8
|
| 3 |
+
size 17077689
|
tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|