Upload quantized model HRM-Text-1B-autoround-MXFP4
Browse files- README.md +118 -0
- config.json +61 -0
- configuration_hrm_text.py +146 -0
- generation_config.json +8 -0
- model.safetensors +3 -0
- modeling_hrm_text.py +644 -0
- quantization_config.json +16 -0
- tokenizer.json +0 -0
- tokenizer_config.json +12 -0
README.md
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model:
|
| 3 |
+
- sapientinc/HRM-Text-1B
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
tags:
|
| 6 |
+
- quantized
|
| 7 |
+
- mxfp4
|
| 8 |
+
- autoround
|
| 9 |
+
- low-bit-open-llm-leaderboard
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# HRM-Text-1B-autoround-MXFP4
|
| 13 |
+
|
| 14 |
+
## Model Details
|
| 15 |
+
|
| 16 |
+
This model is a MXFP4 (Microscaling FP4) quantization of [sapientinc/HRM-Text-1B](https://huggingface.co/sapientinc/HRM-Text-1B) generated by [AutoRound](https://github.com/intel/auto-round). Please follow the license of the original model.
|
| 17 |
+
|
| 18 |
+
## Quantization Details
|
| 19 |
+
|
| 20 |
+
| Attribute | Value |
|
| 21 |
+
|-----------|-------|
|
| 22 |
+
| Base Model | [sapientinc/HRM-Text-1B](https://huggingface.co/sapientinc/HRM-Text-1B) |
|
| 23 |
+
| Quantization Tool | [AutoRound](https://github.com/intel/auto-round) |
|
| 24 |
+
| Quantization Scheme | MXFP4 |
|
| 25 |
+
| Original Size | 2256 MB |
|
| 26 |
+
| Quantized Size | 886 MB |
|
| 27 |
+
|
| 28 |
+
## Evaluation Results
|
| 29 |
+
|
| 30 |
+
| Task | Accuracy |
|
| 31 |
+
|------|----------|
|
| 32 |
+
| hellaswag | 0.2504 |
|
| 33 |
+
| mmlu | 0.2309 |
|
| 34 |
+
| piqa | 0.4951 |
|
| 35 |
+
|
| 36 |
+
## How to Use
|
| 37 |
+
|
| 38 |
+
### HF Usage
|
| 39 |
+
|
| 40 |
+
**Step 1: Install [AutoRound](https://github.com/intel/auto-round)**
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
pip install auto-round
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
**Step 2: Load and run the quantized model**
|
| 47 |
+
|
| 48 |
+
```python
|
| 49 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 50 |
+
|
| 51 |
+
model_name = "HRM-Text-1B-autoround-MXFP4"
|
| 52 |
+
|
| 53 |
+
# load the tokenizer and the model
|
| 54 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 55 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
|
| 56 |
+
|
| 57 |
+
# prepare the model input
|
| 58 |
+
prompt = "Write a quick sort algorithm."
|
| 59 |
+
messages = [{"role": "user", "content": prompt}]
|
| 60 |
+
text = tokenizer.apply_chat_template(
|
| 61 |
+
messages,
|
| 62 |
+
tokenize=False,
|
| 63 |
+
add_generation_prompt=True,
|
| 64 |
+
)
|
| 65 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
| 66 |
+
|
| 67 |
+
# conduct text completion
|
| 68 |
+
generated_ids = model.generate(**model_inputs, max_new_tokens=512)
|
| 69 |
+
output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
|
| 70 |
+
|
| 71 |
+
content = tokenizer.decode(output_ids, skip_special_tokens=True)
|
| 72 |
+
print("content:", content)
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### VLLM Usage
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
vllm serve HRM-Text-1B-autoround-MXFP4 \
|
| 79 |
+
--trust-remote-code \
|
| 80 |
+
--dtype bfloat16 \
|
| 81 |
+
--tensor_parallel_size 1
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
If you encounter any issues, feel free to open an issue on the [AutoRound GitHub repo](https://github.com/intel/auto-round/issues) or provide feedback on the [Low-Bit Open LLM Leaderboard](https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard).
|
| 85 |
+
|
| 86 |
+
## Ethical Considerations and Limitations
|
| 87 |
+
|
| 88 |
+
The model can produce factually incorrect output, and should not be relied on to produce factually accurate information. Because of the limitations of the pretrained model and the finetuning datasets, it is possible that this model could generate lewd, biased or otherwise offensive outputs.
|
| 89 |
+
Therefore, before deploying any applications of the model, developers should perform safety testing.
|
| 90 |
+
|
| 91 |
+
## Caveats and Recommendations
|
| 92 |
+
|
| 93 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model.
|
| 94 |
+
Here are a couple of useful links to learn more about Intel's AI software:
|
| 95 |
+
|
| 96 |
+
- [Intel Neural Compressor](https://github.com/intel/neural-compressor)
|
| 97 |
+
- [AutoRound](https://github.com/intel/auto-round)
|
| 98 |
+
|
| 99 |
+
## Disclaimer
|
| 100 |
+
|
| 101 |
+
The license on this model does not constitute legal advice. We are not responsible for the actions of third parties who use this model. Please consult an attorney before using this model for commercial purposes.
|
| 102 |
+
|
| 103 |
+
## Cite
|
| 104 |
+
|
| 105 |
+
```
|
| 106 |
+
@article{cheng2023optimize,
|
| 107 |
+
title={Optimize weight rounding via signed gradient descent for the quantization of llms},
|
| 108 |
+
author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao and Liu, Yi},
|
| 109 |
+
journal={arXiv preprint arXiv:2309.05516},
|
| 110 |
+
year={2023}
|
| 111 |
+
}
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
[arxiv](https://arxiv.org/abs/2309.05516) [github](https://github.com/intel/auto-round)
|
| 115 |
+
|
| 116 |
+
---
|
| 117 |
+
|
| 118 |
+
*This model is part of the [Intel Low-Bit Open LLM Leaderboard](https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard) initiative.*
|
config.json
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"H_cycles": 2,
|
| 3 |
+
"L_bp_cycles": [
|
| 4 |
+
0,
|
| 5 |
+
3
|
| 6 |
+
],
|
| 7 |
+
"L_cycles": 3,
|
| 8 |
+
"architectures": [
|
| 9 |
+
"HrmTextForCausalLM"
|
| 10 |
+
],
|
| 11 |
+
"attention_bias": false,
|
| 12 |
+
"attention_dropout": 0.0,
|
| 13 |
+
"auto_map": {
|
| 14 |
+
"AutoConfig": "configuration_hrm_text.HrmTextConfig",
|
| 15 |
+
"AutoModel": "modeling_hrm_text.HrmTextModel",
|
| 16 |
+
"AutoModelForCausalLM": "modeling_hrm_text.HrmTextForCausalLM"
|
| 17 |
+
},
|
| 18 |
+
"bos_token_id": 6,
|
| 19 |
+
"dtype": "bfloat16",
|
| 20 |
+
"embedding_scale": 39.191835884530846,
|
| 21 |
+
"eos_token_id": 11,
|
| 22 |
+
"head_dim": 128,
|
| 23 |
+
"hidden_act": "silu",
|
| 24 |
+
"hidden_size": 1536,
|
| 25 |
+
"initializer_range": 0.025515518153991442,
|
| 26 |
+
"intermediate_size": 4096,
|
| 27 |
+
"max_position_embeddings": 4096,
|
| 28 |
+
"mlp_bias": false,
|
| 29 |
+
"model_type": "hrm_text",
|
| 30 |
+
"num_attention_heads": 12,
|
| 31 |
+
"num_hidden_layers": 128,
|
| 32 |
+
"num_key_value_heads": 12,
|
| 33 |
+
"num_layers_per_stack": 16,
|
| 34 |
+
"pad_token_id": 5,
|
| 35 |
+
"prefix_lm": true,
|
| 36 |
+
"quantization_config": {
|
| 37 |
+
"act_bits": 4,
|
| 38 |
+
"act_data_type": "mx_fp",
|
| 39 |
+
"act_dynamic": true,
|
| 40 |
+
"act_group_size": 32,
|
| 41 |
+
"act_sym": true,
|
| 42 |
+
"autoround_version": "0.12.3",
|
| 43 |
+
"bits": 4,
|
| 44 |
+
"data_type": "mx_fp",
|
| 45 |
+
"group_size": 32,
|
| 46 |
+
"iters": 0,
|
| 47 |
+
"low_gpu_mem_usage": true,
|
| 48 |
+
"packing_format": "auto_round:llm_compressor",
|
| 49 |
+
"quant_method": "auto-round",
|
| 50 |
+
"sym": true
|
| 51 |
+
},
|
| 52 |
+
"rms_norm_eps": 1e-06,
|
| 53 |
+
"rope_parameters": {
|
| 54 |
+
"rope_theta": 10000.0,
|
| 55 |
+
"rope_type": "default"
|
| 56 |
+
},
|
| 57 |
+
"tie_word_embeddings": false,
|
| 58 |
+
"transformers_version": "5.9.0",
|
| 59 |
+
"use_cache": true,
|
| 60 |
+
"vocab_size": 65536
|
| 61 |
+
}
|
configuration_hrm_text.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/hrm_text/modular_hrm_text.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_hrm_text.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
from huggingface_hub.dataclasses import strict
|
| 22 |
+
|
| 23 |
+
from transformers.configuration_utils import PreTrainedConfig
|
| 24 |
+
from transformers.modeling_rope_utils import RopeParameters
|
| 25 |
+
from transformers.utils import auto_docstring
|
| 26 |
+
from transformers.utils.generic import is_flash_attention_requested, split_attention_implementation
|
| 27 |
+
from transformers.utils.type_validators import interval
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@auto_docstring(checkpoint="sapientinc/HRM-Text-1B")
|
| 31 |
+
@strict
|
| 32 |
+
class HrmTextConfig(PreTrainedConfig):
|
| 33 |
+
r"""
|
| 34 |
+
H_cycles (`int`, *optional*, defaults to 2):
|
| 35 |
+
Number of high-level cycles.
|
| 36 |
+
L_cycles (`int`, *optional*, defaults to 3):
|
| 37 |
+
Number of low-level cycles per H-cycle.
|
| 38 |
+
L_bp_cycles (`list[int]`, *optional*, defaults to `[2]`):
|
| 39 |
+
Training-time gradient-routing list; left-padded with `1`s up to `L_cycles` inside the model.
|
| 40 |
+
Inference-time no-op.
|
| 41 |
+
embedding_scale (`float`, *optional*):
|
| 42 |
+
Token-embedding multiplier. If `None`, defaults to `1 / initializer_range`.
|
| 43 |
+
prefix_lm (`bool`, *optional*, defaults to `True`):
|
| 44 |
+
Instruction tokens attend bidirectionally, response tokens attend causally.
|
| 45 |
+
num_layers_per_stack (`int`, *optional*):
|
| 46 |
+
Real number of transformer blocks inside each
|
| 47 |
+
of the H / L stacks. Set automatically on first construction: the value passed as
|
| 48 |
+
`num_hidden_layers` is remembered here and `num_hidden_layers` is then rewritten to
|
| 49 |
+
`num_layers_per_stack * H_cycles * (L_cycles + 1)` so that
|
| 50 |
+
`DynamicCache(config=...)` pre-allocates one slot per unique attention invocation
|
| 51 |
+
under the recurrent forward. Do not set this directly on first construction — pass
|
| 52 |
+
the real per-stack count as `num_hidden_layers` and let `__post_init__` split it.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
model_type = "hrm_text"
|
| 56 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 57 |
+
|
| 58 |
+
base_model_tp_plan = {
|
| 59 |
+
**{f"{stack}.layers.*.self_attn.q_proj": "colwise" for stack in ("L_module", "H_module")},
|
| 60 |
+
**{f"{stack}.layers.*.self_attn.k_proj": "colwise" for stack in ("L_module", "H_module")},
|
| 61 |
+
**{f"{stack}.layers.*.self_attn.v_proj": "colwise" for stack in ("L_module", "H_module")},
|
| 62 |
+
**{f"{stack}.layers.*.self_attn.gate_proj": "colwise" for stack in ("L_module", "H_module")},
|
| 63 |
+
**{f"{stack}.layers.*.self_attn.o_proj": "rowwise" for stack in ("L_module", "H_module")},
|
| 64 |
+
**{f"{stack}.layers.*.mlp.gate_proj": "colwise" for stack in ("L_module", "H_module")},
|
| 65 |
+
**{f"{stack}.layers.*.mlp.up_proj": "colwise" for stack in ("L_module", "H_module")},
|
| 66 |
+
**{f"{stack}.layers.*.mlp.down_proj": "rowwise" for stack in ("L_module", "H_module")},
|
| 67 |
+
}
|
| 68 |
+
base_model_pp_plan = {
|
| 69 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 70 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 71 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
vocab_size: int = 151808
|
| 75 |
+
hidden_size: int = 1536
|
| 76 |
+
intermediate_size: int = 4096
|
| 77 |
+
num_hidden_layers: int = 16
|
| 78 |
+
num_attention_heads: int = 12
|
| 79 |
+
hidden_act: str = "silu"
|
| 80 |
+
max_position_embeddings: int = 2048
|
| 81 |
+
initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
|
| 82 |
+
rms_norm_eps: float = 1e-6
|
| 83 |
+
use_cache: bool = True
|
| 84 |
+
pad_token_id: int | None = None
|
| 85 |
+
bos_token_id: int | None = None
|
| 86 |
+
eos_token_id: int | list[int] | None = None
|
| 87 |
+
tie_word_embeddings: bool = False
|
| 88 |
+
rope_parameters: RopeParameters | dict | None = None
|
| 89 |
+
attention_bias: bool = False
|
| 90 |
+
attention_dropout: int | float | None = 0.0
|
| 91 |
+
mlp_bias: bool = False
|
| 92 |
+
head_dim: int = 128
|
| 93 |
+
|
| 94 |
+
H_cycles: int = 2
|
| 95 |
+
L_cycles: int = 3
|
| 96 |
+
L_bp_cycles: list[int] | None = None
|
| 97 |
+
embedding_scale: float | None = None
|
| 98 |
+
prefix_lm: bool = True
|
| 99 |
+
num_layers_per_stack: int | None = None # Usually inferred in post init
|
| 100 |
+
|
| 101 |
+
def __post_init__(self, **kwargs):
|
| 102 |
+
if self.L_bp_cycles is None:
|
| 103 |
+
# Default `[2]` = backprop only the last 2 L-iterations per H-cycle (training-time
|
| 104 |
+
# gradient-routing knob). Left-padding to length `L_cycles` is performed inside
|
| 105 |
+
# [`HrmTextModel`] since it depends on `L_cycles`.
|
| 106 |
+
self.L_bp_cycles = [2]
|
| 107 |
+
|
| 108 |
+
if self.embedding_scale is None:
|
| 109 |
+
self.embedding_scale = 1.0 / self.initializer_range
|
| 110 |
+
|
| 111 |
+
if self.num_layers_per_stack is None:
|
| 112 |
+
# Initial construction, or legacy checkpoint where `num_hidden_layers` carries the
|
| 113 |
+
# real per-stack count: remember that value and rewrite `num_hidden_layers` to the
|
| 114 |
+
# inflated total, so standard HF cache allocation gives us one slot per unique
|
| 115 |
+
# attention invocation. Serialised configs round-trip as (inflated, real) pairs.
|
| 116 |
+
self.num_layers_per_stack = self.num_hidden_layers
|
| 117 |
+
self.num_hidden_layers = self.num_layers_per_stack * self.H_cycles * (self.L_cycles + 1)
|
| 118 |
+
|
| 119 |
+
super().__post_init__(**kwargs)
|
| 120 |
+
|
| 121 |
+
def validate_architecture(self):
|
| 122 |
+
"""Part of `@strict`-powered validation. Validates the architecture of the config."""
|
| 123 |
+
if self.hidden_size % self.num_attention_heads != 0:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
|
| 126 |
+
f"heads ({self.num_attention_heads})."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
@property
|
| 130 |
+
def _attn_implementation(self):
|
| 131 |
+
return self._attn_implementation_internal
|
| 132 |
+
|
| 133 |
+
@_attn_implementation.setter
|
| 134 |
+
def _attn_implementation(self, value: str | dict | None):
|
| 135 |
+
if value is not None and self.prefix_lm:
|
| 136 |
+
_, base_implementation = split_attention_implementation(value)
|
| 137 |
+
if is_flash_attention_requested(requested_attention_implementation=base_implementation):
|
| 138 |
+
raise ValueError(
|
| 139 |
+
f"`attn_implementation={value!r}` is not supported when "
|
| 140 |
+
"`config.prefix_lm=True`: FlashAttention cannot represent the PrefixLM 4-D mask "
|
| 141 |
+
"overlay. Use `'sdpa'` (default) or `'flex_attention'`, or set `config.prefix_lm=False`."
|
| 142 |
+
)
|
| 143 |
+
PreTrainedConfig._attn_implementation.__set__(self, value)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
__all__ = ["HrmTextConfig"]
|
generation_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 6,
|
| 4 |
+
"do_sample": true,
|
| 5 |
+
"eos_token_id": 11,
|
| 6 |
+
"pad_token_id": 5,
|
| 7 |
+
"transformers_version": "5.9.0"
|
| 8 |
+
}
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:905484c08c12cf186265a02a23b6f780366999d2caf2b6614fafd6b684fd2499
|
| 3 |
+
size 924093024
|
modeling_hrm_text.py
ADDED
|
@@ -0,0 +1,644 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/hrm_text/modular_hrm_text.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_hrm_text.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
from collections.abc import Callable
|
| 22 |
+
from contextlib import nullcontext
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from torch import nn
|
| 27 |
+
|
| 28 |
+
from transformers import initialization as init
|
| 29 |
+
from transformers.activations import ACT2FN
|
| 30 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 31 |
+
from transformers.configuration_utils import PreTrainedConfig
|
| 32 |
+
from transformers.generation import GenerationMixin
|
| 33 |
+
from transformers.integrations import use_kernel_func_from_hub, use_kernelized_func
|
| 34 |
+
from transformers.masking_utils import create_causal_mask, create_masks_for_generate
|
| 35 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 36 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 37 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 38 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 39 |
+
from transformers.processing_utils import Unpack
|
| 40 |
+
from transformers.utils import auto_docstring, can_return_tuple, logging
|
| 41 |
+
from transformers.utils.generic import (
|
| 42 |
+
TransformersKwargs,
|
| 43 |
+
is_flash_attention_requested,
|
| 44 |
+
maybe_autocast,
|
| 45 |
+
merge_with_config_defaults,
|
| 46 |
+
split_attention_implementation,
|
| 47 |
+
)
|
| 48 |
+
from transformers.utils.output_capturing import capture_outputs
|
| 49 |
+
from .configuration_hrm_text import HrmTextConfig
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
logger = logging.get_logger(__name__)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class HrmTextRMSNorm(torch.nn.Module):
|
| 56 |
+
def __init__(self, eps: float = 1e-6):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.eps = eps
|
| 59 |
+
|
| 60 |
+
def _norm(self, x):
|
| 61 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
return self._norm(x.float()).type_as(x)
|
| 65 |
+
|
| 66 |
+
def extra_repr(self):
|
| 67 |
+
return f"eps={self.eps}"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class HrmTextMLP(nn.Module):
|
| 71 |
+
def __init__(self, config):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.config = config
|
| 74 |
+
self.hidden_size = config.hidden_size
|
| 75 |
+
self.intermediate_size = config.intermediate_size
|
| 76 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 77 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 78 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
| 79 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 80 |
+
|
| 81 |
+
def forward(self, x):
|
| 82 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 83 |
+
return down_proj
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def rotate_half(x):
|
| 87 |
+
"""Rotates half the hidden dims of the input."""
|
| 88 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 89 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 90 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@use_kernel_func_from_hub("rotary_pos_emb")
|
| 94 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 95 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
q (`torch.Tensor`): The query tensor.
|
| 99 |
+
k (`torch.Tensor`): The key tensor.
|
| 100 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 101 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 102 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 103 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 104 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 105 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 106 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 107 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 108 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 109 |
+
Returns:
|
| 110 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 111 |
+
"""
|
| 112 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 113 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 114 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 115 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 116 |
+
return q_embed, k_embed
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 120 |
+
"""
|
| 121 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 122 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 123 |
+
"""
|
| 124 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 125 |
+
if n_rep == 1:
|
| 126 |
+
return hidden_states
|
| 127 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 128 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def eager_attention_forward(
|
| 132 |
+
module: nn.Module,
|
| 133 |
+
query: torch.Tensor,
|
| 134 |
+
key: torch.Tensor,
|
| 135 |
+
value: torch.Tensor,
|
| 136 |
+
attention_mask: torch.Tensor | None,
|
| 137 |
+
scaling: float,
|
| 138 |
+
dropout: float = 0.0,
|
| 139 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 140 |
+
):
|
| 141 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 142 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 143 |
+
|
| 144 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 145 |
+
if attention_mask is not None:
|
| 146 |
+
attn_weights = attn_weights + attention_mask
|
| 147 |
+
|
| 148 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 149 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 150 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 151 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 152 |
+
|
| 153 |
+
return attn_output, attn_weights
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@use_kernelized_func(apply_rotary_pos_emb)
|
| 157 |
+
class HrmTextAttention(nn.Module):
|
| 158 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 159 |
+
|
| 160 |
+
def __init__(self, config: HrmTextConfig, layer_idx: int):
|
| 161 |
+
super().__init__()
|
| 162 |
+
self.config = config
|
| 163 |
+
self.layer_idx = layer_idx
|
| 164 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 165 |
+
self.num_key_value_groups = 1 # Uses MHA instead of GQA
|
| 166 |
+
self.scaling = self.head_dim**-0.5
|
| 167 |
+
self.attention_dropout = config.attention_dropout
|
| 168 |
+
self.is_causal = True
|
| 169 |
+
|
| 170 |
+
self.q_proj = nn.Linear(
|
| 171 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 172 |
+
)
|
| 173 |
+
self.k_proj = nn.Linear(
|
| 174 |
+
config.hidden_size,
|
| 175 |
+
config.num_attention_heads * self.head_dim,
|
| 176 |
+
bias=config.attention_bias,
|
| 177 |
+
)
|
| 178 |
+
self.v_proj = nn.Linear(
|
| 179 |
+
config.hidden_size,
|
| 180 |
+
config.num_attention_heads * self.head_dim,
|
| 181 |
+
bias=config.attention_bias,
|
| 182 |
+
)
|
| 183 |
+
self.o_proj = nn.Linear(
|
| 184 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 185 |
+
)
|
| 186 |
+
# Additional sigmoid gate applied at the end
|
| 187 |
+
self.gate_proj = nn.Linear(
|
| 188 |
+
config.hidden_size,
|
| 189 |
+
config.num_attention_heads * self.head_dim,
|
| 190 |
+
bias=config.attention_bias,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
def forward(
|
| 194 |
+
self,
|
| 195 |
+
hidden_states: torch.Tensor,
|
| 196 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 197 |
+
attention_mask: torch.Tensor | None = None,
|
| 198 |
+
past_key_values: Cache | None = None,
|
| 199 |
+
cycle_offset: int = 0,
|
| 200 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 201 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 202 |
+
input_shape = hidden_states.shape[:-1]
|
| 203 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 204 |
+
|
| 205 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 206 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 207 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 208 |
+
gate_states = self.gate_proj(hidden_states).view(hidden_shape)
|
| 209 |
+
|
| 210 |
+
cos, sin = position_embeddings
|
| 211 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 212 |
+
|
| 213 |
+
if past_key_values is not None:
|
| 214 |
+
# Adjust cache slot by `cycle_offset` which is determined by it's current recurrent step through the stacks
|
| 215 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx + cycle_offset)
|
| 216 |
+
|
| 217 |
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 218 |
+
self.config._attn_implementation, eager_attention_forward
|
| 219 |
+
)
|
| 220 |
+
attn_output, attn_weights = attention_interface(
|
| 221 |
+
self,
|
| 222 |
+
query_states,
|
| 223 |
+
key_states,
|
| 224 |
+
value_states,
|
| 225 |
+
attention_mask,
|
| 226 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 227 |
+
scaling=self.scaling,
|
| 228 |
+
**kwargs,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# Additional sigmoid gating (similar to Qwen3Next)
|
| 232 |
+
attn_output = torch.sigmoid(gate_states) * attn_output
|
| 233 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 234 |
+
attn_output = self.o_proj(attn_output)
|
| 235 |
+
return attn_output, attn_weights
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class HrmTextDecoderLayer(GradientCheckpointingLayer):
|
| 239 |
+
def __init__(self, config: HrmTextConfig, layer_idx: int):
|
| 240 |
+
super().__init__()
|
| 241 |
+
self.hidden_size = config.hidden_size
|
| 242 |
+
|
| 243 |
+
self.self_attn = HrmTextAttention(config=config, layer_idx=layer_idx)
|
| 244 |
+
|
| 245 |
+
self.mlp = HrmTextMLP(config)
|
| 246 |
+
self.input_layernorm = HrmTextRMSNorm(eps=config.rms_norm_eps)
|
| 247 |
+
self.post_attention_layernorm = HrmTextRMSNorm(eps=config.rms_norm_eps)
|
| 248 |
+
|
| 249 |
+
def forward(
|
| 250 |
+
self,
|
| 251 |
+
hidden_states: torch.Tensor,
|
| 252 |
+
attention_mask: torch.Tensor | None = None,
|
| 253 |
+
position_ids: torch.LongTensor | None = None,
|
| 254 |
+
past_key_values: Cache | None = None,
|
| 255 |
+
use_cache: bool | None = False,
|
| 256 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 257 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 258 |
+
) -> torch.Tensor:
|
| 259 |
+
residual = hidden_states
|
| 260 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 261 |
+
# Self Attention
|
| 262 |
+
hidden_states, _ = self.self_attn(
|
| 263 |
+
hidden_states=hidden_states,
|
| 264 |
+
attention_mask=attention_mask,
|
| 265 |
+
position_ids=position_ids,
|
| 266 |
+
past_key_values=past_key_values,
|
| 267 |
+
use_cache=use_cache,
|
| 268 |
+
position_embeddings=position_embeddings,
|
| 269 |
+
**kwargs,
|
| 270 |
+
)
|
| 271 |
+
hidden_states = residual + hidden_states
|
| 272 |
+
|
| 273 |
+
# Fully Connected
|
| 274 |
+
residual = hidden_states
|
| 275 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 276 |
+
hidden_states = self.mlp(hidden_states)
|
| 277 |
+
hidden_states = residual + hidden_states
|
| 278 |
+
return hidden_states
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class HrmTextStack(nn.Module):
|
| 282 |
+
"""A single transformer stack — used twice inside, once as H module and once as L module"""
|
| 283 |
+
|
| 284 |
+
def __init__(self, config: HrmTextConfig):
|
| 285 |
+
super().__init__()
|
| 286 |
+
self.layers = nn.ModuleList(
|
| 287 |
+
[HrmTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers_per_stack)]
|
| 288 |
+
)
|
| 289 |
+
self.final_norm = HrmTextRMSNorm(eps=config.rms_norm_eps)
|
| 290 |
+
|
| 291 |
+
def forward(
|
| 292 |
+
self,
|
| 293 |
+
hidden_states: torch.Tensor,
|
| 294 |
+
attention_mask: torch.Tensor | None = None,
|
| 295 |
+
past_key_values: Cache | None = None,
|
| 296 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 297 |
+
cycle_offset: int = 0,
|
| 298 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 299 |
+
) -> torch.Tensor:
|
| 300 |
+
for layer in self.layers:
|
| 301 |
+
hidden_states = layer(
|
| 302 |
+
hidden_states,
|
| 303 |
+
attention_mask=attention_mask,
|
| 304 |
+
past_key_values=past_key_values,
|
| 305 |
+
position_embeddings=position_embeddings,
|
| 306 |
+
cycle_offset=cycle_offset,
|
| 307 |
+
**kwargs,
|
| 308 |
+
)
|
| 309 |
+
return self.final_norm(hidden_states)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
@auto_docstring
|
| 313 |
+
class HrmTextPreTrainedModel(PreTrainedModel):
|
| 314 |
+
config: HrmTextConfig
|
| 315 |
+
base_model_prefix = "model"
|
| 316 |
+
supports_gradient_checkpointing = True
|
| 317 |
+
_no_split_modules = ["HrmTextDecoderLayer"]
|
| 318 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 319 |
+
_supports_flash_attn = True
|
| 320 |
+
_supports_sdpa = True
|
| 321 |
+
_supports_flex_attn = True
|
| 322 |
+
|
| 323 |
+
_can_compile_fullgraph = True
|
| 324 |
+
_supports_attention_backend = True
|
| 325 |
+
_can_record_outputs = {
|
| 326 |
+
"hidden_states": HrmTextDecoderLayer,
|
| 327 |
+
"attentions": HrmTextAttention,
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
def _check_and_adjust_attn_implementation(
|
| 331 |
+
self, attn_implementation: str | None, is_init_check: bool = False, allow_all_kernels: bool = False
|
| 332 |
+
) -> str:
|
| 333 |
+
if attn_implementation is not None and self.config.prefix_lm:
|
| 334 |
+
_, base_implementation = split_attention_implementation(attn_implementation)
|
| 335 |
+
if is_flash_attention_requested(requested_attention_implementation=base_implementation):
|
| 336 |
+
raise ValueError(
|
| 337 |
+
f"`attn_implementation={attn_implementation!r}` is not supported when "
|
| 338 |
+
"`config.prefix_lm=True`: FlashAttention cannot represent the PrefixLM 4-D mask "
|
| 339 |
+
"overlay. Use `'sdpa'` (default) or `'flex_attention'`, or set `config.prefix_lm=False`."
|
| 340 |
+
)
|
| 341 |
+
return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check, allow_all_kernels)
|
| 342 |
+
|
| 343 |
+
@torch.no_grad()
|
| 344 |
+
def _init_weights(self, module):
|
| 345 |
+
super()._init_weights(module)
|
| 346 |
+
if isinstance(module, HrmTextModel):
|
| 347 |
+
init.zeros_(module.z_L_init)
|
| 348 |
+
# `z_L_init` is the frozen low-cycle initial state and never trains.
|
| 349 |
+
module.z_L_init.requires_grad_(False) # trf-ignore: TRF012
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class HrmTextRotaryEmbedding(nn.Module):
|
| 353 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 354 |
+
|
| 355 |
+
def __init__(self, config: HrmTextConfig, device=None):
|
| 356 |
+
super().__init__()
|
| 357 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 358 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 359 |
+
|
| 360 |
+
self.config = config
|
| 361 |
+
|
| 362 |
+
self.rope_type = self.config.rope_parameters["rope_type"]
|
| 363 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 364 |
+
if self.rope_type != "default":
|
| 365 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 366 |
+
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
| 367 |
+
|
| 368 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 369 |
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def compute_default_rope_parameters(
|
| 373 |
+
config: HrmTextConfig | None = None,
|
| 374 |
+
device: Optional["torch.device"] = None,
|
| 375 |
+
seq_len: int | None = None,
|
| 376 |
+
) -> tuple["torch.Tensor", float]:
|
| 377 |
+
"""
|
| 378 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 379 |
+
Args:
|
| 380 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 381 |
+
The model configuration.
|
| 382 |
+
device (`torch.device`):
|
| 383 |
+
The device to use for initialization of the inverse frequencies.
|
| 384 |
+
seq_len (`int`, *optional*):
|
| 385 |
+
The current sequence length. Unused for this type of RoPE.
|
| 386 |
+
Returns:
|
| 387 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 388 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 389 |
+
"""
|
| 390 |
+
base = config.rope_parameters["rope_theta"]
|
| 391 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 392 |
+
|
| 393 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 394 |
+
|
| 395 |
+
# Compute the inverse frequencies
|
| 396 |
+
inv_freq = 1.0 / (
|
| 397 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 398 |
+
)
|
| 399 |
+
return inv_freq, attention_factor
|
| 400 |
+
|
| 401 |
+
@torch.no_grad()
|
| 402 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 403 |
+
def forward(self, x, position_ids):
|
| 404 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 405 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 406 |
+
|
| 407 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 408 |
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 409 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 410 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 411 |
+
cos = emb.cos() * self.attention_scaling
|
| 412 |
+
sin = emb.sin() * self.attention_scaling
|
| 413 |
+
|
| 414 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
@auto_docstring
|
| 418 |
+
class HrmTextModel(HrmTextPreTrainedModel):
|
| 419 |
+
def __init__(self, config: HrmTextConfig):
|
| 420 |
+
super().__init__(config)
|
| 421 |
+
self.padding_idx = config.pad_token_id
|
| 422 |
+
self.vocab_size = config.vocab_size
|
| 423 |
+
|
| 424 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 425 |
+
self.rotary_emb = HrmTextRotaryEmbedding(config=config)
|
| 426 |
+
self.gradient_checkpointing = False
|
| 427 |
+
|
| 428 |
+
self.embedding_scale = config.embedding_scale
|
| 429 |
+
|
| 430 |
+
# Recursive module structures
|
| 431 |
+
self.L_module = HrmTextStack(config)
|
| 432 |
+
self.H_module = HrmTextStack(config)
|
| 433 |
+
# Initial state for the low cycle module
|
| 434 |
+
self.z_L_init = nn.Parameter(torch.zeros(config.hidden_size), requires_grad=False)
|
| 435 |
+
|
| 436 |
+
raw_bp = list(config.L_bp_cycles)
|
| 437 |
+
self.L_bp_cycles_padded = [1] * max(0, config.H_cycles - len(raw_bp)) + raw_bp
|
| 438 |
+
|
| 439 |
+
# Initialize weights and apply final processing
|
| 440 |
+
self.post_init()
|
| 441 |
+
|
| 442 |
+
@merge_with_config_defaults
|
| 443 |
+
@capture_outputs
|
| 444 |
+
@auto_docstring
|
| 445 |
+
def forward(
|
| 446 |
+
self,
|
| 447 |
+
input_ids: torch.LongTensor | None = None,
|
| 448 |
+
attention_mask: torch.Tensor | None = None,
|
| 449 |
+
position_ids: torch.LongTensor | None = None,
|
| 450 |
+
past_key_values: Cache | None = None,
|
| 451 |
+
token_type_ids: torch.LongTensor | None = None,
|
| 452 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 453 |
+
use_cache: bool | None = None,
|
| 454 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 455 |
+
) -> BaseModelOutputWithPast:
|
| 456 |
+
r"""
|
| 457 |
+
token_type_ids (`torch.LongTensor` of shape `(batch, seq_len)`, *optional*):
|
| 458 |
+
Per-position bidirectional/causal indicator. Tokens with `token_type_ids == 1`
|
| 459 |
+
form a single bidirectional block; all other positions are causal.
|
| 460 |
+
"""
|
| 461 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 462 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 463 |
+
|
| 464 |
+
if inputs_embeds is None:
|
| 465 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 466 |
+
# Additional scaling on the input embeds
|
| 467 |
+
inputs_embeds = inputs_embeds * self.embedding_scale
|
| 468 |
+
|
| 469 |
+
if use_cache and past_key_values is None:
|
| 470 |
+
past_key_values = DynamicCache(config=self.config)
|
| 471 |
+
|
| 472 |
+
if position_ids is None:
|
| 473 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 474 |
+
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
|
| 475 |
+
position_ids = position_ids.unsqueeze(0)
|
| 476 |
+
|
| 477 |
+
# Create mask with optional prefix-based bidirectionality
|
| 478 |
+
mask_kwargs = {
|
| 479 |
+
"config": self.config,
|
| 480 |
+
"inputs_embeds": inputs_embeds,
|
| 481 |
+
"attention_mask": attention_mask,
|
| 482 |
+
"past_key_values": past_key_values,
|
| 483 |
+
"position_ids": position_ids,
|
| 484 |
+
}
|
| 485 |
+
is_first_iteration = past_key_values is None or not past_key_values.is_initialized
|
| 486 |
+
if token_type_ids is not None and is_first_iteration:
|
| 487 |
+
if self.config.prefix_lm:
|
| 488 |
+
mask_kwargs["block_sequence_ids"] = torch.where(token_type_ids == 1, 0, -1)
|
| 489 |
+
else:
|
| 490 |
+
logger.warning_once("`token_type_ids` was provided but `config.prefix_lm=False`; ignoring it.")
|
| 491 |
+
|
| 492 |
+
attention_mask = create_causal_mask(**mask_kwargs)
|
| 493 |
+
position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
|
| 494 |
+
|
| 495 |
+
# Hierarchical (H/L)-cycle recurrence
|
| 496 |
+
#
|
| 497 |
+
# `z_H` - slow / high-level state
|
| 498 |
+
hidden_states_high_cycle = inputs_embeds
|
| 499 |
+
# `z_L` - fast / low-level state
|
| 500 |
+
hidden_states_low_cycle = (
|
| 501 |
+
self.z_L_init.to(dtype=hidden_states_high_cycle.dtype, device=hidden_states_high_cycle.device)
|
| 502 |
+
.expand_as(hidden_states_high_cycle)
|
| 503 |
+
.contiguous()
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# Cache-slot layout under the recurrent forward:
|
| 507 |
+
#
|
| 508 |
+
# slot(h, l, layer) = (h * (L_cycles + 1) + l) * num_layers_per_stack + layer
|
| 509 |
+
# ^— L-stack invocation at (h, l)
|
| 510 |
+
# slot(h, H, layer) = (h * (L_cycles + 1) + L_cycles) * num_layers_per_stack + layer
|
| 511 |
+
# ^— trailing H-stack invocation
|
| 512 |
+
#
|
| 513 |
+
# That totals `num_layers_per_stack * H_cycles * (L_cycles + 1)` slots, i.e. the `config.num_hidden_layers`.
|
| 514 |
+
num_layers_per_stack = self.config.num_layers_per_stack
|
| 515 |
+
for high_cycle_idx in range(self.config.H_cycles):
|
| 516 |
+
# `L_bp_cycles` k-step grad trick: only the trailing `num_grad_iterations` of the
|
| 517 |
+
# `L_cycles` inner iterations propagate gradients; earlier iterations run under
|
| 518 |
+
# `torch.no_grad()` to bound activation memory.
|
| 519 |
+
num_grad_iterations = (
|
| 520 |
+
self.L_bp_cycles_padded[high_cycle_idx] if high_cycle_idx < len(self.L_bp_cycles_padded) else 1
|
| 521 |
+
)
|
| 522 |
+
grad_threshold = self.config.L_cycles - num_grad_iterations
|
| 523 |
+
for low_cycle_idx in range(self.config.L_cycles):
|
| 524 |
+
cycle_offset = (high_cycle_idx * (self.config.L_cycles + 1) + low_cycle_idx) * num_layers_per_stack
|
| 525 |
+
ctx = nullcontext() if low_cycle_idx >= grad_threshold else torch.no_grad()
|
| 526 |
+
with ctx:
|
| 527 |
+
hidden_states_low_cycle = self.L_module(
|
| 528 |
+
hidden_states_low_cycle.to(hidden_states_high_cycle.device) + hidden_states_high_cycle,
|
| 529 |
+
attention_mask=attention_mask,
|
| 530 |
+
past_key_values=past_key_values,
|
| 531 |
+
position_embeddings=position_embeddings,
|
| 532 |
+
position_ids=position_ids,
|
| 533 |
+
cycle_offset=cycle_offset,
|
| 534 |
+
**kwargs,
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
cycle_offset = (high_cycle_idx * (self.config.L_cycles + 1) + self.config.L_cycles) * num_layers_per_stack
|
| 538 |
+
|
| 539 |
+
hidden_states_high_cycle = self.H_module(
|
| 540 |
+
hidden_states_high_cycle + hidden_states_low_cycle.to(hidden_states_high_cycle.device),
|
| 541 |
+
attention_mask=attention_mask,
|
| 542 |
+
past_key_values=past_key_values,
|
| 543 |
+
position_embeddings=position_embeddings,
|
| 544 |
+
position_ids=position_ids,
|
| 545 |
+
cycle_offset=cycle_offset,
|
| 546 |
+
**kwargs,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
return BaseModelOutputWithPast(
|
| 550 |
+
last_hidden_state=hidden_states_high_cycle,
|
| 551 |
+
past_key_values=past_key_values,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
@auto_docstring
|
| 556 |
+
class HrmTextForCausalLM(HrmTextPreTrainedModel, GenerationMixin):
|
| 557 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 558 |
+
_tp_plan = {"lm_head": "colwise_gather_output"}
|
| 559 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 560 |
+
|
| 561 |
+
def __init__(self, config):
|
| 562 |
+
super().__init__(config)
|
| 563 |
+
self.model = HrmTextModel(config)
|
| 564 |
+
self.vocab_size = config.vocab_size
|
| 565 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 566 |
+
|
| 567 |
+
# Initialize weights and apply final processing
|
| 568 |
+
self.post_init()
|
| 569 |
+
|
| 570 |
+
@can_return_tuple
|
| 571 |
+
@auto_docstring
|
| 572 |
+
def forward(
|
| 573 |
+
self,
|
| 574 |
+
input_ids: torch.LongTensor | None = None,
|
| 575 |
+
attention_mask: torch.Tensor | None = None,
|
| 576 |
+
position_ids: torch.LongTensor | None = None,
|
| 577 |
+
past_key_values: Cache | None = None,
|
| 578 |
+
token_type_ids: torch.LongTensor | None = None,
|
| 579 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 580 |
+
labels: torch.LongTensor | None = None,
|
| 581 |
+
use_cache: bool | None = None,
|
| 582 |
+
logits_to_keep: int | torch.Tensor = 0,
|
| 583 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 584 |
+
) -> CausalLMOutputWithPast:
|
| 585 |
+
r"""
|
| 586 |
+
token_type_ids (`torch.LongTensor` of shape `(batch, seq_len)`, *optional*):
|
| 587 |
+
Per-position bidirectional/causal indicator. Tokens with `token_type_ids == 1`
|
| 588 |
+
form a single bidirectional block; all other positions are causal.
|
| 589 |
+
"""
|
| 590 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 591 |
+
input_ids=input_ids,
|
| 592 |
+
attention_mask=attention_mask,
|
| 593 |
+
position_ids=position_ids,
|
| 594 |
+
past_key_values=past_key_values,
|
| 595 |
+
token_type_ids=token_type_ids,
|
| 596 |
+
inputs_embeds=inputs_embeds,
|
| 597 |
+
use_cache=use_cache,
|
| 598 |
+
**kwargs,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
hidden_states = outputs.last_hidden_state
|
| 602 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 603 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 604 |
+
|
| 605 |
+
loss = None
|
| 606 |
+
if labels is not None:
|
| 607 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 608 |
+
|
| 609 |
+
return CausalLMOutputWithPast(
|
| 610 |
+
loss=loss,
|
| 611 |
+
logits=logits,
|
| 612 |
+
past_key_values=outputs.past_key_values,
|
| 613 |
+
hidden_states=outputs.hidden_states,
|
| 614 |
+
attentions=outputs.attentions,
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
@staticmethod
|
| 618 |
+
def create_masks_for_generate(
|
| 619 |
+
config: PreTrainedConfig,
|
| 620 |
+
inputs_embeds: torch.Tensor,
|
| 621 |
+
attention_mask: torch.Tensor | None,
|
| 622 |
+
past_key_values: Cache | None,
|
| 623 |
+
position_ids: torch.Tensor | None,
|
| 624 |
+
token_type_ids: torch.Tensor | None = None,
|
| 625 |
+
is_first_iteration: bool | None = False,
|
| 626 |
+
**kwargs,
|
| 627 |
+
) -> dict:
|
| 628 |
+
mask_kwargs = {
|
| 629 |
+
"config": config,
|
| 630 |
+
"inputs_embeds": inputs_embeds,
|
| 631 |
+
"attention_mask": attention_mask,
|
| 632 |
+
"past_key_values": past_key_values,
|
| 633 |
+
"position_ids": position_ids,
|
| 634 |
+
}
|
| 635 |
+
if token_type_ids is not None and is_first_iteration:
|
| 636 |
+
if config.prefix_lm:
|
| 637 |
+
mask_kwargs["block_sequence_ids"] = torch.where(token_type_ids == 1, 0, -1)
|
| 638 |
+
else:
|
| 639 |
+
logger.warning_once("`token_type_ids` was provided but `config.prefix_lm=False`; ignoring it.")
|
| 640 |
+
|
| 641 |
+
return create_masks_for_generate(**mask_kwargs)
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
__all__ = ["HrmTextForCausalLM", "HrmTextModel", "HrmTextPreTrainedModel"]
|
quantization_config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bits": 4,
|
| 3 |
+
"act_bits": 4,
|
| 4 |
+
"data_type": "mx_fp",
|
| 5 |
+
"act_data_type": "mx_fp",
|
| 6 |
+
"group_size": 32,
|
| 7 |
+
"act_group_size": 32,
|
| 8 |
+
"sym": true,
|
| 9 |
+
"act_sym": true,
|
| 10 |
+
"act_dynamic": true,
|
| 11 |
+
"iters": 0,
|
| 12 |
+
"low_gpu_mem_usage": true,
|
| 13 |
+
"autoround_version": "0.12.3",
|
| 14 |
+
"quant_method": "auto-round",
|
| 15 |
+
"packing_format": "auto_round:llm_compressor"
|
| 16 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": null,
|
| 3 |
+
"backend": "tokenizers",
|
| 4 |
+
"bos_token": "<|im_start|>",
|
| 5 |
+
"eos_token": "<|box_end|>",
|
| 6 |
+
"is_local": false,
|
| 7 |
+
"local_files_only": false,
|
| 8 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 9 |
+
"pad_token": "<|endoftext|>",
|
| 10 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 11 |
+
"unk_token": "<|endoftext|>"
|
| 12 |
+
}
|