Commit 04464f3 · committed by INC4AI · verified · 1 Parent(s): 4faa170

Upload quantized model HRM-Text-1B-autoround-MXFP4

README.md ADDED
@@ -0,0 +1,118 @@
+ ---
+ base_model:
+ - sapientinc/HRM-Text-1B
+ pipeline_tag: text-generation
+ tags:
+ - quantized
+ - mxfp4
+ - autoround
+ - low-bit-open-llm-leaderboard
+ ---
+
+ # HRM-Text-1B-autoround-MXFP4
+
+ ## Model Details
+
+ This model is an MXFP4 (Microscaling FP4) quantization of [sapientinc/HRM-Text-1B](https://huggingface.co/sapientinc/HRM-Text-1B) generated by [AutoRound](https://github.com/intel/auto-round). Please follow the license of the original model.
+
+ ## Quantization Details
+
+ | Attribute | Value |
+ |-----------|-------|
+ | Base Model | [sapientinc/HRM-Text-1B](https://huggingface.co/sapientinc/HRM-Text-1B) |
+ | Quantization Tool | [AutoRound](https://github.com/intel/auto-round) |
+ | Quantization Scheme | MXFP4 |
+ | Original Size | 2256 MB |
+ | Quantized Size | 886 MB |
+
+ ## Evaluation Results
+
+ | Task | Accuracy |
+ |------|----------|
+ | hellaswag | 0.2504 |
+ | mmlu | 0.2309 |
+ | piqa | 0.4951 |
+
+ ## How to Use
+
+ ### HF Usage
+
+ **Step 1: Install [AutoRound](https://github.com/intel/auto-round)**
+
+ ```bash
+ pip install auto-round
+ ```
+
+ **Step 2: Load and run the quantized model**
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "HRM-Text-1B-autoround-MXFP4"
+
+ # load the tokenizer and the model
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
+
+ # prepare the model input
+ prompt = "Write a quick sort algorithm."
+ messages = [{"role": "user", "content": prompt}]
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True,
+ )
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+ # generate a completion, then keep only the newly generated tokens
+ generated_ids = model.generate(**model_inputs, max_new_tokens=512)
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
+
+ content = tokenizer.decode(output_ids, skip_special_tokens=True)
+ print("content:", content)
+ ```
+
+ ### vLLM Usage
+
+ ```bash
+ vllm serve HRM-Text-1B-autoround-MXFP4 \
+     --trust-remote-code \
+     --dtype bfloat16 \
+     --tensor-parallel-size 1
+ ```
+
+ If you encounter any issues, feel free to open an issue on the [AutoRound GitHub repo](https://github.com/intel/auto-round/issues) or provide feedback on the [Low-Bit Open LLM Leaderboard](https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard).
+
+ ## Ethical Considerations and Limitations
+
+ The model can produce factually incorrect output, and should not be relied on to produce factually accurate information. Because of the limitations of the pretrained model and the finetuning datasets, it is possible that this model could generate lewd, biased or otherwise offensive outputs.
+ Therefore, before deploying any applications of the model, developers should perform safety testing.
+
+ ## Caveats and Recommendations
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model.
+ Here are a couple of useful links to learn more about Intel's AI software:
+
+ - [Intel Neural Compressor](https://github.com/intel/neural-compressor)
+ - [AutoRound](https://github.com/intel/auto-round)
+
+ ## Disclaimer
+
+ The license on this model does not constitute legal advice. We are not responsible for the actions of third parties who use this model. Please consult an attorney before using this model for commercial purposes.
+
+ ## Cite
+
+ ```
+ @article{cheng2023optimize,
+   title={Optimize weight rounding via signed gradient descent for the quantization of llms},
+   author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao and Liu, Yi},
+   journal={arXiv preprint arXiv:2309.05516},
+   year={2023}
+ }
+ ```
+
+ [arxiv](https://arxiv.org/abs/2309.05516) [github](https://github.com/intel/auto-round)
+
+ ---
+
+ *This model is part of the [Intel Low-Bit Open LLM Leaderboard](https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard) initiative.*
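
Reviewer note on the README: the card names the scheme (MXFP4) but does not show what it does to a tensor. The sketch below is a simplified, pure-Python illustration of the microscaling idea as described in the OCP Microscaling Formats specification: each block of 32 values shares one power-of-two scale, and each element is stored as a 4-bit E2M1 float. It is not AutoRound's implementation. The function names are hypothetical, the group size of 32 is taken from `group_size` in the `config.json` of this commit, and rounding ties are resolved toward the smaller magnitude for brevity.

```python
import math

# Magnitudes representable by an FP4 E2M1 element (the sign bit is handled separately).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
E2M1_EMAX = 2  # largest E2M1 exponent: 6.0 == 1.5 * 2**2


def quantize_block_mxfp4(block):
    """Fake-quantize one block: shared power-of-two scale, E2M1 elements (hypothetical helper)."""
    max_abs = max(abs(v) for v in block)
    if max_abs == 0.0:
        return [0.0] * len(block)
    # Choose the power-of-two scale so the largest magnitude lands near the top of the grid.
    scale = 2.0 ** (math.floor(math.log2(max_abs)) - E2M1_EMAX)
    out = []
    for v in block:
        mag = min(abs(v) / scale, E2M1_GRID[-1])  # saturate at 6.0
        nearest = min(E2M1_GRID, key=lambda g: abs(g - mag))  # ties go to the smaller value
        out.append(math.copysign(nearest * scale, v))
    return out


def quantize_mxfp4(values, group_size=32):
    """Apply the block quantizer to consecutive groups of `group_size` values."""
    out = []
    for start in range(0, len(values), group_size):
        out.extend(quantize_block_mxfp4(values[start:start + group_size]))
    return out


print(quantize_block_mxfp4([6.0, 1.0, -3.0, 0.4]))
```

With only eight magnitudes per block, small values next to a large outlier lose most of their precision, which is why the shared scale is chosen per 32-element group rather than per tensor.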
config.json ADDED
@@ -0,0 +1,61 @@
+ {
+   "H_cycles": 2,
+   "L_bp_cycles": [
+     0,
+     3
+   ],
+   "L_cycles": 3,
+   "architectures": [
+     "HrmTextForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_hrm_text.HrmTextConfig",
+     "AutoModel": "modeling_hrm_text.HrmTextModel",
+     "AutoModelForCausalLM": "modeling_hrm_text.HrmTextForCausalLM"
+   },
+   "bos_token_id": 6,
+   "dtype": "bfloat16",
+   "embedding_scale": 39.191835884530846,
+   "eos_token_id": 11,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 1536,
+   "initializer_range": 0.025515518153991442,
+   "intermediate_size": 4096,
+   "max_position_embeddings": 4096,
+   "mlp_bias": false,
+   "model_type": "hrm_text",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 128,
+   "num_key_value_heads": 12,
+   "num_layers_per_stack": 16,
+   "pad_token_id": 5,
+   "prefix_lm": true,
+   "quantization_config": {
+     "act_bits": 4,
+     "act_data_type": "mx_fp",
+     "act_dynamic": true,
+     "act_group_size": 32,
+     "act_sym": true,
+     "autoround_version": "0.12.3",
+     "bits": 4,
+     "data_type": "mx_fp",
+     "group_size": 32,
+     "iters": 0,
+     "low_gpu_mem_usage": true,
+     "packing_format": "auto_round:llm_compressor",
+     "quant_method": "auto-round",
+     "sym": true
+   },
+   "rms_norm_eps": 1e-06,
+   "rope_parameters": {
+     "rope_theta": 10000.0,
+     "rope_type": "default"
+   },
+   "tie_word_embeddings": false,
+   "transformers_version": "5.9.0",
+   "use_cache": true,
+   "vocab_size": 65536
+ }
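
Reviewer note on `config.json`: several of the values above are derived from one another, so they can be cross-checked. The sketch below is a minimal consistency check, assuming the relations documented in `configuration_hrm_text.py` in this same commit (`num_hidden_layers = num_layers_per_stack * H_cycles * (L_cycles + 1)` and `embedding_scale = 1 / initializer_range`). The helper name `check_config` and the last check (that the quantization group size divides the layer widths evenly) are my own assumptions, not something the commit states.

```python
import json
import math

# Excerpt of the related fields from the config.json above.
CONFIG_EXCERPT = json.loads("""
{
  "H_cycles": 2,
  "L_cycles": 3,
  "num_layers_per_stack": 16,
  "num_hidden_layers": 128,
  "initializer_range": 0.025515518153991442,
  "embedding_scale": 39.191835884530846,
  "hidden_size": 1536,
  "intermediate_size": 4096,
  "num_attention_heads": 12,
  "head_dim": 128,
  "quantization_config": {"group_size": 32, "act_group_size": 32}
}
""")


def check_config(cfg):
    """Return named consistency checks; True means the relation holds (hypothetical helper)."""
    group = cfg["quantization_config"]["group_size"]
    return {
        # 16 blocks per stack * 2 H-cycles * (3 L-cycles + 1) = 128 cache slots.
        "inflated_layer_count": cfg["num_hidden_layers"]
        == cfg["num_layers_per_stack"] * cfg["H_cycles"] * (cfg["L_cycles"] + 1),
        "embedding_scale": math.isclose(
            cfg["embedding_scale"], 1.0 / cfg["initializer_range"], rel_tol=1e-9
        ),
        # In this checkpoint the attention projection width equals hidden_size.
        "attention_width": cfg["num_attention_heads"] * cfg["head_dim"] == cfg["hidden_size"],
        # Assumption: 32-element quantization groups should tile both layer widths evenly.
        "group_divides_dims": cfg["hidden_size"] % group == 0
        and cfg["intermediate_size"] % group == 0,
    }


results = check_config(CONFIG_EXCERPT)
for name, ok in results.items():
    print(f"{name}: {'ok' if ok else 'MISMATCH'}")
```

All four relations hold for the values in this commit, which suggests the serialized config is the already-inflated form rather than a first-construction config.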
configuration_hrm_text.py ADDED
@@ -0,0 +1,146 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/hrm_text/modular_hrm_text.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_hrm_text.py file directly. One of our CI checks enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from huggingface_hub.dataclasses import strict
+
+ from transformers.configuration_utils import PreTrainedConfig
+ from transformers.modeling_rope_utils import RopeParameters
+ from transformers.utils import auto_docstring
+ from transformers.utils.generic import is_flash_attention_requested, split_attention_implementation
+ from transformers.utils.type_validators import interval
+
+
+ @auto_docstring(checkpoint="sapientinc/HRM-Text-1B")
+ @strict
+ class HrmTextConfig(PreTrainedConfig):
+     r"""
+     H_cycles (`int`, *optional*, defaults to 2):
+         Number of high-level cycles.
+     L_cycles (`int`, *optional*, defaults to 3):
+         Number of low-level cycles per H-cycle.
+     L_bp_cycles (`list[int]`, *optional*, defaults to `[2]`):
+         Training-time gradient-routing list; left-padded with `1`s up to `L_cycles` inside the model.
+         Inference-time no-op.
+     embedding_scale (`float`, *optional*):
+         Token-embedding multiplier. If `None`, defaults to `1 / initializer_range`.
+     prefix_lm (`bool`, *optional*, defaults to `True`):
+         Instruction tokens attend bidirectionally, response tokens attend causally.
+     num_layers_per_stack (`int`, *optional*):
+         Real number of transformer blocks inside each
+         of the H / L stacks. Set automatically on first construction: the value passed as
+         `num_hidden_layers` is remembered here and `num_hidden_layers` is then rewritten to
+         `num_layers_per_stack * H_cycles * (L_cycles + 1)` so that
+         `DynamicCache(config=...)` pre-allocates one slot per unique attention invocation
+         under the recurrent forward. Do not set this directly on first construction; pass
+         the real per-stack count as `num_hidden_layers` and let `__post_init__` split it.
+     """
+
+     model_type = "hrm_text"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     base_model_tp_plan = {
+         **{f"{stack}.layers.*.self_attn.q_proj": "colwise" for stack in ("L_module", "H_module")},
+         **{f"{stack}.layers.*.self_attn.k_proj": "colwise" for stack in ("L_module", "H_module")},
+         **{f"{stack}.layers.*.self_attn.v_proj": "colwise" for stack in ("L_module", "H_module")},
+         **{f"{stack}.layers.*.self_attn.gate_proj": "colwise" for stack in ("L_module", "H_module")},
+         **{f"{stack}.layers.*.self_attn.o_proj": "rowwise" for stack in ("L_module", "H_module")},
+         **{f"{stack}.layers.*.mlp.gate_proj": "colwise" for stack in ("L_module", "H_module")},
+         **{f"{stack}.layers.*.mlp.up_proj": "colwise" for stack in ("L_module", "H_module")},
+         **{f"{stack}.layers.*.mlp.down_proj": "rowwise" for stack in ("L_module", "H_module")},
+     }
+     base_model_pp_plan = {
+         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+         "norm": (["hidden_states"], ["hidden_states"]),
+     }
+
+     vocab_size: int = 151808
+     hidden_size: int = 1536
+     intermediate_size: int = 4096
+     num_hidden_layers: int = 16
+     num_attention_heads: int = 12
+     hidden_act: str = "silu"
+     max_position_embeddings: int = 2048
+     initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
+     rms_norm_eps: float = 1e-6
+     use_cache: bool = True
+     pad_token_id: int | None = None
+     bos_token_id: int | None = None
+     eos_token_id: int | list[int] | None = None
+     tie_word_embeddings: bool = False
+     rope_parameters: RopeParameters | dict | None = None
+     attention_bias: bool = False
+     attention_dropout: int | float | None = 0.0
+     mlp_bias: bool = False
+     head_dim: int = 128
+
+     H_cycles: int = 2
+     L_cycles: int = 3
+     L_bp_cycles: list[int] | None = None
+     embedding_scale: float | None = None
+     prefix_lm: bool = True
+     num_layers_per_stack: int | None = None  # Usually inferred in post init
+
+     def __post_init__(self, **kwargs):
+         if self.L_bp_cycles is None:
+             # Default `[2]` = backprop only the last 2 L-iterations per H-cycle (training-time
+             # gradient-routing knob). Left-padding to length `L_cycles` is performed inside
+             # [`HrmTextModel`] since it depends on `L_cycles`.
+             self.L_bp_cycles = [2]
+
+         if self.embedding_scale is None:
+             self.embedding_scale = 1.0 / self.initializer_range
+
+         if self.num_layers_per_stack is None:
+             # Initial construction, or legacy checkpoint where `num_hidden_layers` carries the
+             # real per-stack count: remember that value and rewrite `num_hidden_layers` to the
+             # inflated total, so standard HF cache allocation gives us one slot per unique
+             # attention invocation. Serialised configs round-trip as (inflated, real) pairs.
+             self.num_layers_per_stack = self.num_hidden_layers
+             self.num_hidden_layers = self.num_layers_per_stack * self.H_cycles * (self.L_cycles + 1)
+
+         super().__post_init__(**kwargs)
+
+     def validate_architecture(self):
+         """Part of `@strict`-powered validation. Validates the architecture of the config."""
+         if self.hidden_size % self.num_attention_heads != 0:
+             raise ValueError(
+                 f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
+                 f"heads ({self.num_attention_heads})."
+             )
+
+     @property
+     def _attn_implementation(self):
+         return self._attn_implementation_internal
+
+     @_attn_implementation.setter
+     def _attn_implementation(self, value: str | dict | None):
+         if value is not None and self.prefix_lm:
+             _, base_implementation = split_attention_implementation(value)
+             if is_flash_attention_requested(requested_attention_implementation=base_implementation):
+                 raise ValueError(
+                     f"`attn_implementation={value!r}` is not supported when "
+                     "`config.prefix_lm=True`: FlashAttention cannot represent the PrefixLM 4-D mask "
+                     "overlay. Use `'sdpa'` (default) or `'flex_attention'`, or set `config.prefix_lm=False`."
+                 )
+         PreTrainedConfig._attn_implementation.__set__(self, value)
+
+
+ __all__ = ["HrmTextConfig"]
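
Reviewer note on `__post_init__`: the layer-count rewrite is easy to misread, so here is a stripped-down sketch of just that branch. `StackCountSketch` is a hypothetical plain dataclass that mirrors the four relevant fields; it does not use `PreTrainedConfig` or `@strict`, and the round trip through `asdict` only stands in for saving and reloading a config.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class StackCountSketch:
    """Hypothetical stand-in for the layer-count fields of `HrmTextConfig`."""

    num_hidden_layers: int = 16
    H_cycles: int = 2
    L_cycles: int = 3
    num_layers_per_stack: Optional[int] = None

    def __post_init__(self):
        if self.num_layers_per_stack is None:
            # First construction: the caller passed the real per-stack count.
            self.num_layers_per_stack = self.num_hidden_layers
            # Inflate to one cache slot per attention call in the recurrent forward.
            self.num_hidden_layers = self.num_layers_per_stack * self.H_cycles * (self.L_cycles + 1)


fresh = StackCountSketch(num_hidden_layers=16)
reloaded = StackCountSketch(**asdict(fresh))  # simulates loading a serialized config

print(fresh.num_hidden_layers, fresh.num_layers_per_stack)  # 128 16
print(reloaded == fresh)  # True
```

Because the serialized form carries both numbers, reloading skips the branch and the count is not inflated a second time. This matches the `(128, 16)` pair stored in `config.json`.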
generation_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 6,
+   "do_sample": true,
+   "eos_token_id": 11,
+   "pad_token_id": 5,
+   "transformers_version": "5.9.0"
+ }
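
Reviewer note on `do_sample`: with `"do_sample": true` as the default, repeated `generate` calls on the same prompt can return different text unless the caller overrides it. The sketch below illustrates the difference between greedy selection and sampling for a single decoding step. It is a toy in pure Python, not the `transformers` generation loop, and `pick_next_token` is a hypothetical helper.

```python
import math
import random


def softmax(logits):
    """Convert raw scores into probabilities, subtracting the max for stability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pick_next_token(logits, do_sample, rng):
    """Return a token index: arg-max when greedy, a random draw when sampling."""
    if not do_sample:
        return max(range(len(logits)), key=logits.__getitem__)
    return rng.choices(range(len(logits)), weights=softmax(logits), k=1)[0]


rng = random.Random(0)
logits = [0.1, 0.2, 3.0]
print("greedy:", pick_next_token(logits, do_sample=False, rng=rng))
print("sampled:", [pick_next_token(logits, do_sample=True, rng=rng) for _ in range(10)])
```

For reproducible outputs, callers would either fix the random seed or request greedy decoding explicitly.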
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:905484c08c12cf186265a02a23b6f780366999d2caf2b6614fafd6b684fd2499
+ size 924093024
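
Reviewer note on `model.safetensors`: the three lines above are a Git LFS pointer, not the weights themselves. A pointer is a short text file of `key value` lines, so its fields can be read with a few lines of Python. The helper name `parse_lfs_pointer` is hypothetical, and the sketch only parses the pointer and converts the byte count. It does not download or verify the actual file.

```python
POINTER_TEXT = """\
version https://git-lfs.github.com/spec/v1
oid sha256:905484c08c12cf186265a02a23b6f780366999d2caf2b6614fafd6b684fd2499
size 924093024
"""


def parse_lfs_pointer(text):
    """Split each `key value` line and return the pointer fields (hypothetical helper)."""
    fields = dict(line.split(" ", 1) for line in text.splitlines() if line.strip())
    algorithm, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "oid_algorithm": algorithm,
        "oid": digest,
        "size_bytes": int(fields["size"]),
    }


pointer = parse_lfs_pointer(POINTER_TEXT)
size_mb = pointer["size_bytes"] / 1_000_000  # decimal megabytes
size_mib = pointer["size_bytes"] / 2**20     # binary mebibytes
print(f"{pointer['oid_algorithm']} digest, {size_mb:.2f} MB ({size_mib:.2f} MiB)")
```

The weights file alone is about 924 MB in decimal units, or about 881 MiB. The README lists the quantized size as 886 MB without saying which unit it uses or which files it counts.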
modeling_hrm_text.py ADDED
@@ -0,0 +1,644 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/hrm_text/modular_hrm_text.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_hrm_text.py file directly. One of our CI checks enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections.abc import Callable
+ from contextlib import nullcontext
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from transformers import initialization as init
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.configuration_utils import PreTrainedConfig
+ from transformers.generation import GenerationMixin
+ from transformers.integrations import use_kernel_func_from_hub, use_kernelized_func
+ from transformers.masking_utils import create_causal_mask, create_masks_for_generate
+ from transformers.modeling_layers import GradientCheckpointingLayer
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from transformers.processing_utils import Unpack
+ from transformers.utils import auto_docstring, can_return_tuple, logging
+ from transformers.utils.generic import (
+     TransformersKwargs,
+     is_flash_attention_requested,
+     maybe_autocast,
+     merge_with_config_defaults,
+     split_attention_implementation,
+ )
+ from transformers.utils.output_capturing import capture_outputs
+ from .configuration_hrm_text import HrmTextConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class HrmTextRMSNorm(torch.nn.Module):
+     def __init__(self, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         return self._norm(x.float()).type_as(x)
+
+     def extra_repr(self):
+         return f"eps={self.eps}"
+
+
+ class HrmTextMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ @use_kernel_func_from_hub("rotary_pos_emb")
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor | None,
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         attn_weights = attn_weights + attention_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ @use_kernelized_func(apply_rotary_pos_emb)
+ class HrmTextAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, config: HrmTextConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+         self.num_key_value_groups = 1  # Uses MHA instead of GQA
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = config.attention_dropout
+         self.is_causal = True
+
+         self.q_proj = nn.Linear(
+             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.k_proj = nn.Linear(
+             config.hidden_size,
+             config.num_attention_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.v_proj = nn.Linear(
+             config.hidden_size,
+             config.num_attention_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.o_proj = nn.Linear(
+             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+         )
+         # Additional sigmoid gate applied at the end
+         self.gate_proj = nn.Linear(
+             config.hidden_size,
+             config.num_attention_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         attention_mask: torch.Tensor | None = None,
+         past_key_values: Cache | None = None,
+         cycle_offset: int = 0,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         gate_states = self.gate_proj(hidden_states).view(hidden_shape)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_values is not None:
+             # Adjust the cache slot by `cycle_offset`, which is determined by the current recurrent step through the stacks
+             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx + cycle_offset)
+
+         attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+             self.config._attn_implementation, eager_attention_forward
+         )
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )
+
+         # Additional sigmoid gating (similar to Qwen3Next)
+         attn_output = torch.sigmoid(gate_states) * attn_output
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+
+
+ class HrmTextDecoderLayer(GradientCheckpointingLayer):
+     def __init__(self, config: HrmTextConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         self.self_attn = HrmTextAttention(config=config, layer_idx=layer_idx)
+
+         self.mlp = HrmTextMLP(config)
+         self.input_layernorm = HrmTextRMSNorm(eps=config.rms_norm_eps)
+         self.post_attention_layernorm = HrmTextRMSNorm(eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: Cache | None = None,
+         use_cache: bool | None = False,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         # Self Attention
+         hidden_states, _ = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             position_embeddings=position_embeddings,
+             **kwargs,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ class HrmTextStack(nn.Module):
+     """A single transformer stack, used twice inside the model: once as the H module and once as the L module."""
+
+     def __init__(self, config: HrmTextConfig):
+         super().__init__()
+         self.layers = nn.ModuleList(
+             [HrmTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers_per_stack)]
+         )
+         self.final_norm = HrmTextRMSNorm(eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         past_key_values: Cache | None = None,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         cycle_offset: int = 0,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
+         for layer in self.layers:
+             hidden_states = layer(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 past_key_values=past_key_values,
+                 position_embeddings=position_embeddings,
+                 cycle_offset=cycle_offset,
+                 **kwargs,
+             )
+         return self.final_norm(hidden_states)
+
+
+ @auto_docstring
+ class HrmTextPreTrainedModel(PreTrainedModel):
+     config: HrmTextConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["HrmTextDecoderLayer"]
+     _skip_keys_device_placement = ["past_key_values"]
+     _supports_flash_attn = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+
+     _can_compile_fullgraph = True
+     _supports_attention_backend = True
+     _can_record_outputs = {
+         "hidden_states": HrmTextDecoderLayer,
+         "attentions": HrmTextAttention,
+     }
+
+     def _check_and_adjust_attn_implementation(
+         self, attn_implementation: str | None, is_init_check: bool = False, allow_all_kernels: bool = False
+     ) -> str:
+         if attn_implementation is not None and self.config.prefix_lm:
+             _, base_implementation = split_attention_implementation(attn_implementation)
+             if is_flash_attention_requested(requested_attention_implementation=base_implementation):
+                 raise ValueError(
+                     f"`attn_implementation={attn_implementation!r}` is not supported when "
+                     "`config.prefix_lm=True`: FlashAttention cannot represent the PrefixLM 4-D mask "
+                     "overlay. Use `'sdpa'` (default) or `'flex_attention'`, or set `config.prefix_lm=False`."
+                 )
+         return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check, allow_all_kernels)
+
+     @torch.no_grad()
+     def _init_weights(self, module):
+         super()._init_weights(module)
+         if isinstance(module, HrmTextModel):
+             init.zeros_(module.z_L_init)
+             # `z_L_init` is the frozen low-cycle initial state and never trains.
+             module.z_L_init.requires_grad_(False)  # trf-ignore: TRF012
+
+
+ class HrmTextRotaryEmbedding(nn.Module):
+     inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+     def __init__(self, config: HrmTextConfig, device=None):
+         super().__init__()
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+
+         self.rope_type = self.config.rope_parameters["rope_type"]
+         rope_init_fn: Callable = self.compute_default_rope_parameters
+         if self.rope_type != "default":
+             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+
+     @staticmethod
+     def compute_default_rope_parameters(
+         config: HrmTextConfig | None = None,
+         device: Optional["torch.device"] = None,
+         seq_len: int | None = None,
+     ) -> tuple["torch.Tensor", float]:
+         """
+         Computes the inverse frequencies according to the original RoPE implementation
+         Args:
+             config ([`~transformers.PreTrainedConfig`]):
+                 The model configuration.
+             device (`torch.device`):
+                 The device to use for initialization of the inverse frequencies.
+             seq_len (`int`, *optional*):
+                 The current sequence length. Unused for this type of RoPE.
+         Returns:
+             Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+             post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+         """
+         base = config.rope_parameters["rope_theta"]
+         dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+         attention_factor = 1.0  # Unused in this type of RoPE
+
+         # Compute the inverse frequencies
+         inv_freq = 1.0 / (
+             base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+         )
+         return inv_freq, attention_factor
+
+     @torch.no_grad()
+     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+     def forward(self, x, position_ids):
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+         position_ids_expanded = position_ids[:, None, :].float()
+
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class HrmTextModel(HrmTextPreTrainedModel):
+    def __init__(self, config: HrmTextConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.rotary_emb = HrmTextRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        self.embedding_scale = config.embedding_scale
+
+        # Recursive module structures
+        self.L_module = HrmTextStack(config)
+        self.H_module = HrmTextStack(config)
+        # Initial state for the low cycle module
+        self.z_L_init = nn.Parameter(torch.zeros(config.hidden_size), requires_grad=False)
+
+        raw_bp = list(config.L_bp_cycles)
+        self.L_bp_cycles_padded = [1] * max(0, config.H_cycles - len(raw_bp)) + raw_bp
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @merge_with_config_defaults
+    @capture_outputs
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_values: Cache | None = None,
+        token_type_ids: torch.LongTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        use_cache: bool | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
+        r"""
+        token_type_ids (`torch.LongTensor` of shape `(batch, seq_len)`, *optional*):
+            Per-position bidirectional/causal indicator. Tokens with `token_type_ids == 1`
+            form a single bidirectional block; all other positions are causal.
+        """
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+            # Additional scaling on the input embeds
+            inputs_embeds = inputs_embeds * self.embedding_scale
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if position_ids is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+            position_ids = position_ids.unsqueeze(0)
+
+        # Create mask with optional prefix-based bidirectionality
+        mask_kwargs = {
+            "config": self.config,
+            "inputs_embeds": inputs_embeds,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "position_ids": position_ids,
+        }
+        is_first_iteration = past_key_values is None or not past_key_values.is_initialized
+        if token_type_ids is not None and is_first_iteration:
+            if self.config.prefix_lm:
+                mask_kwargs["block_sequence_ids"] = torch.where(token_type_ids == 1, 0, -1)
+            else:
+                logger.warning_once("`token_type_ids` was provided but `config.prefix_lm=False`; ignoring it.")
+
+        attention_mask = create_causal_mask(**mask_kwargs)
+        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
+
+        # Hierarchical (H/L)-cycle recurrence
+        #
+        # `z_H` - slow / high-level state
+        hidden_states_high_cycle = inputs_embeds
+        # `z_L` - fast / low-level state
+        hidden_states_low_cycle = (
+            self.z_L_init.to(dtype=hidden_states_high_cycle.dtype, device=hidden_states_high_cycle.device)
+            .expand_as(hidden_states_high_cycle)
+            .contiguous()
+        )
+
+        # Cache-slot layout under the recurrent forward:
+        #
+        #   slot(h, l, layer) = (h * (L_cycles + 1) + l) * num_layers_per_stack + layer
+        #                       ^— L-stack invocation at (h, l)
+        #   slot(h, H, layer) = (h * (L_cycles + 1) + L_cycles) * num_layers_per_stack + layer
+        #                       ^— trailing H-stack invocation
+        #
+        # That totals `num_layers_per_stack * H_cycles * (L_cycles + 1)` slots, i.e. the `config.num_hidden_layers`.
+        num_layers_per_stack = self.config.num_layers_per_stack
+        for high_cycle_idx in range(self.config.H_cycles):
+            # `L_bp_cycles` k-step grad trick: only the trailing `num_grad_iterations` of the
+            # `L_cycles` inner iterations propagate gradients; earlier iterations run under
+            # `torch.no_grad()` to bound activation memory.
+            num_grad_iterations = (
+                self.L_bp_cycles_padded[high_cycle_idx] if high_cycle_idx < len(self.L_bp_cycles_padded) else 1
+            )
+            grad_threshold = self.config.L_cycles - num_grad_iterations
+            for low_cycle_idx in range(self.config.L_cycles):
+                cycle_offset = (high_cycle_idx * (self.config.L_cycles + 1) + low_cycle_idx) * num_layers_per_stack
+                ctx = nullcontext() if low_cycle_idx >= grad_threshold else torch.no_grad()
+                with ctx:
+                    hidden_states_low_cycle = self.L_module(
+                        hidden_states_low_cycle.to(hidden_states_high_cycle.device) + hidden_states_high_cycle,
+                        attention_mask=attention_mask,
+                        past_key_values=past_key_values,
+                        position_embeddings=position_embeddings,
+                        position_ids=position_ids,
+                        cycle_offset=cycle_offset,
+                        **kwargs,
+                    )
+
+            cycle_offset = (high_cycle_idx * (self.config.L_cycles + 1) + self.config.L_cycles) * num_layers_per_stack
+
+            hidden_states_high_cycle = self.H_module(
+                hidden_states_high_cycle + hidden_states_low_cycle.to(hidden_states_high_cycle.device),
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                position_embeddings=position_embeddings,
+                position_ids=position_ids,
+                cycle_offset=cycle_offset,
+                **kwargs,
+            )
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states_high_cycle,
+            past_key_values=past_key_values,
+        )
+
+
+@auto_docstring
+class HrmTextForCausalLM(HrmTextPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+    _tp_plan = {"lm_head": "colwise_gather_output"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = HrmTextModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_values: Cache | None = None,
+        token_type_ids: torch.LongTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        labels: torch.LongTensor | None = None,
+        use_cache: bool | None = None,
+        logits_to_keep: int | torch.Tensor = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        token_type_ids (`torch.LongTensor` of shape `(batch, seq_len)`, *optional*):
+            Per-position bidirectional/causal indicator. Tokens with `token_type_ids == 1`
+            form a single bidirectional block; all other positions are causal.
+        """
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    @staticmethod
+    def create_masks_for_generate(
+        config: PreTrainedConfig,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor | None,
+        past_key_values: Cache | None,
+        position_ids: torch.Tensor | None,
+        token_type_ids: torch.Tensor | None = None,
+        is_first_iteration: bool | None = False,
+        **kwargs,
+    ) -> dict:
+        mask_kwargs = {
+            "config": config,
+            "inputs_embeds": inputs_embeds,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "position_ids": position_ids,
+        }
+        if token_type_ids is not None and is_first_iteration:
+            if config.prefix_lm:
+                mask_kwargs["block_sequence_ids"] = torch.where(token_type_ids == 1, 0, -1)
+            else:
+                logger.warning_once("`token_type_ids` was provided but `config.prefix_lm=False`; ignoring it.")
+
+        return create_masks_for_generate(**mask_kwargs)
+
+
+__all__ = ["HrmTextForCausalLM", "HrmTextModel", "HrmTextPreTrainedModel"]
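
The cache-slot layout and the `L_bp_cycles` gradient schedule described in the comments of `HrmTextModel.forward` can be checked with a small standalone sketch. It is plain Python and does not import the model; the function names `slot_layout` and `grad_schedule` and the example sizes (`H_cycles=2`, `L_cycles=3`, `num_layers_per_stack=4`, `L_bp_cycles=[2]`) are hypothetical and chosen only for illustration.

```python
def slot_layout(H_cycles, L_cycles, num_layers_per_stack):
    """List (stack, h, l, slot) tuples in invocation order, following the
    slot formula quoted in the forward-pass comment."""
    slots = []
    for h in range(H_cycles):
        # L-stack invocations at (h, l)
        for l in range(L_cycles):
            offset = (h * (L_cycles + 1) + l) * num_layers_per_stack
            slots += [("L", h, l, offset + layer) for layer in range(num_layers_per_stack)]
        # Trailing H-stack invocation for this high cycle
        offset = (h * (L_cycles + 1) + L_cycles) * num_layers_per_stack
        slots += [("H", h, L_cycles, offset + layer) for layer in range(num_layers_per_stack)]
    return slots


def grad_schedule(H_cycles, L_cycles, L_bp_cycles):
    """For each high cycle, flag which low-cycle iterations would run with
    gradients enabled (True) rather than under no_grad (False)."""
    # Left-pad with 1s, mirroring `L_bp_cycles_padded` in `HrmTextModel.__init__`
    padded = [1] * max(0, H_cycles - len(L_bp_cycles)) + list(L_bp_cycles)
    schedule = []
    for h in range(H_cycles):
        num_grad_iterations = padded[h] if h < len(padded) else 1
        grad_threshold = L_cycles - num_grad_iterations
        schedule.append([l >= grad_threshold for l in range(L_cycles)])
    return schedule


# Hypothetical sizes, not taken from the released config
layout = slot_layout(H_cycles=2, L_cycles=3, num_layers_per_stack=4)
schedule = grad_schedule(H_cycles=2, L_cycles=3, L_bp_cycles=[2])
```

With these sizes the slots come out as the contiguous range 0 to 31, which matches the `num_layers_per_stack * H_cycles * (L_cycles + 1)` total stated in the comment, and only the trailing low-cycle iterations of each high cycle are flagged for gradients.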
quantization_config.json ADDED
@@ -0,0 +1,16 @@
+{
+  "bits": 4,
+  "act_bits": 4,
+  "data_type": "mx_fp",
+  "act_data_type": "mx_fp",
+  "group_size": 32,
+  "act_group_size": 32,
+  "sym": true,
+  "act_sym": true,
+  "act_dynamic": true,
+  "iters": 0,
+  "low_gpu_mem_usage": true,
+  "autoround_version": "0.12.3",
+  "quant_method": "auto-round",
+  "packing_format": "auto_round:llm_compressor"
+}
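
To make the `"bits": 4`, `"data_type": "mx_fp"`, and `"group_size": 32` fields concrete, here is a minimal sketch of block-wise MXFP4 quantization in plain Python. It assumes the usual microscaling layout (FP4 E2M1 elements sharing one power-of-two scale per 32-element group) and is not AutoRound's implementation; the helper names `quantize_block` and `dequantize_block` are hypothetical, and ties are rounded toward the smaller magnitude for simplicity.

```python
import math

# Magnitudes representable by an FP4 E2M1 element (sign handled separately)
FP4_E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
GROUP_SIZE = 32   # "group_size" in quantization_config.json
ELEMENT_EMAX = 2  # largest E2M1 exponent: 6.0 == 1.5 * 2**2


def quantize_block(values):
    """Return (shared_exponent, element_values) for one group of weights."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return 0, [0.0] * len(values)
    # Shared power-of-two scale chosen so the largest value lands in FP4 range
    shared_exp = math.floor(math.log2(amax)) - ELEMENT_EMAX
    scale = 2.0 ** shared_exp
    elements = []
    for v in values:
        scaled = min(abs(v) / scale, FP4_E2M1_GRID[-1])  # clamp to max magnitude
        nearest = min(FP4_E2M1_GRID, key=lambda g: abs(g - scaled))
        elements.append(math.copysign(nearest, v))
    return shared_exp, elements


def dequantize_block(shared_exp, elements):
    """Reconstruct approximate weights from the shared exponent and elements."""
    scale = 2.0 ** shared_exp
    return [e * scale for e in elements]


# Storage cost per quantized value: 4 element bits plus an 8-bit scale per group
bits_per_weight = 4 + 8 / GROUP_SIZE

block = [0.75, -0.375, 0.1, 0.0] + [0.0] * (GROUP_SIZE - 4)
shared_exp, elements = quantize_block(block)
restored = dequantize_block(shared_exp, elements)
```

Values that already sit on the scaled FP4 grid (here `0.75` and `-0.375`) round-trip exactly, while `0.1` is snapped to the nearest grid point.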
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
+{
+  "add_prefix_space": null,
+  "backend": "tokenizers",
+  "bos_token": "<|im_start|>",
+  "eos_token": "<|box_end|>",
+  "is_local": false,
+  "local_files_only": false,
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": "<|endoftext|>",
+  "tokenizer_class": "Qwen2Tokenizer",
+  "unk_token": "<|endoftext|>"
+}
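
The `model_max_length` value above is the integer form of `1e30`, the placeholder Transformers writes when a tokenizer has no explicit length limit, so it should not be read as a real context size. The sketch below parses a trimmed copy of the config and substitutes a caller-supplied limit; the helper `effective_max_length` and the fallback value `4096` are hypothetical and not taken from the model's configuration.

```python
import json

# Trimmed copy of tokenizer_config.json from the hunk above
TOKENIZER_CONFIG = """
{
  "bos_token": "<|im_start|>",
  "eos_token": "<|box_end|>",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<|endoftext|>",
  "tokenizer_class": "Qwen2Tokenizer",
  "unk_token": "<|endoftext|>"
}
"""

NO_LIMIT_SENTINEL = int(1e30)  # placeholder meaning "no limit configured"


def effective_max_length(config, fallback):
    """Return the configured limit, or `fallback` when only the sentinel is set."""
    limit = config.get("model_max_length", NO_LIMIT_SENTINEL)
    return fallback if limit >= NO_LIMIT_SENTINEL else limit


config = json.loads(TOKENIZER_CONFIG)
max_length = effective_max_length(config, fallback=4096)  # 4096 is an arbitrary example
shares_pad_and_unk = config["pad_token"] == config["unk_token"]
```

In practice the fallback would come from the model's own configuration (for example its maximum position embeddings) rather than a hard-coded number.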