abcd1927 committed on
Commit
0e4353b
·
0 Parent(s):

HRM-Text-1B

Browse files
.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ banner.jpg filter=lfs diff=lfs merge=lfs -text
37
+ benchmark_scatter.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ library_name: transformers
6
+ pipeline_tag: text-generation
7
+ tags:
8
+ - hrm
9
+ - hierarchical-reasoning
10
+ - prefix-lm
11
+ - base-model
12
+ ---
13
+
14
+ ![HRM-Text banner](banner.jpg)
15
+
16
+ ![Benchmark scatter: FLOPs and tokens vs benchmark average for HRM-Text-1B vs comparable models](benchmark_scatter.png)
17
+
18
+ <p align="center">
19
+ <a href="https://github.com/sapientinc/HRM-Text"><img alt="GitHub" src="https://img.shields.io/badge/GitHub-sapientinc%2FHRM--Text-181717?logo=github&logoColor=white"></a>
20
+ </p>
21
+
22
+ # HRM-Text-1B
23
+
24
+ A 1B-parameter base language model built on the **Hierarchical Reasoning Model (HRM)** architecture, trained from scratch on a curated text corpus by Sapient Intelligence.
25
+
26
+ HRM is a dual-timescale recurrent architecture: two Transformer modules (H = high-level / slow, L = low-level / fast) iterate over the same input embeddings with additive state injection (`z_L + z_H`). Each of the `H_cycles` outer cycles runs the L module `L_cycles` times and then the H module once, so a single forward pass makes `H_cycles × (L_cycles + 1)` stack passes. This gives effectively unbounded compute depth at bounded parameter count.
27
+
28
+ ## Disclaimer
29
+
30
+ This is a **base** model. It is pre-trained on a PrefixLM objective with condition prefix tokens and has **not** been instruction-tuned, RLHF'd, or otherwise post-trained. For any serious downstream use we recommend post-training (SFT and/or RL) on task-specific data; the base checkpoint is meant as a starting point, not a finished assistant.
31
+
32
+ Practical guidance for prompting the raw base model:
33
+
34
+ - **NLP tasks (classification, extraction, structured output, short-form QA)**: use the `direct` condition with 2–8 few-shot in-context examples. `direct` + few-shot is the strongest zero-extra-training setup we have measured; pure zero-shot is noticeably weaker.
35
+ - **Reasoning / math / open-ended generation**: use the **composite condition** `synth,cot`. This is *one* composite prefix, not two alternatives — at tokenization time the comma-separated tags are mapped to their prefix tokens and concatenated, in order, into a single prefix block. So `synth,cot` produces the two-token prefix `<|quad_end|><|object_ref_end|>` (synth first, then cot), wrapped in the usual `<|im_start|>` … `<|im_end|>` envelope. Under this composite the model exhibits some chain-of-thought / instruct-like behavior — enough to answer many zero-shot math and reasoning prompts in a step-by-step style — but quality is uneven and below an instruction-tuned model of comparable size. Treat this "instruct" ability as a side effect of the pre-training mix, not a guaranteed capability.
36
+
37
+ The four single tags and their prefix tokens (for reference; you can compose any subset, comma-separated, in the order you want them emitted):
38
+
39
+ - `direct` → `<|object_ref_start|>` — direct answer, no CoT
40
+ - `cot` → `<|object_ref_end|>` — chain-of-thought
41
+ - `noisy` → `<|quad_start|>` — noisy / web-crawl style
42
+ - `synth` → `<|quad_end|>` — synthetic / curated style
43
+
44
+ ## Requirements
45
+
46
+ The `hrm_text` model class has been merged into Transformers `main`. The PyPI release containing it may still be in flight; until then, install Transformers directly from the upstream `main` branch:
47
+
48
+ ```bash
49
+ pip install --upgrade "git+https://github.com/huggingface/transformers.git@main"
50
+ ```
51
+
52
+ ## Model details
53
+
54
+ | Field | Value |
55
+ |---|---|
56
+ | Parameters | ~1 B |
57
+ | Hidden size | 1536 |
58
+ | Layers (per H / L stack) | 16 |
59
+ | Attention heads | 12 (MHA, head_dim 128) |
60
+ | H_cycles × L_cycles | 2 × 3 |
61
+ | Max sequence length | 4096 |
62
+ | Vocabulary | 65,536 |
63
+ | Embedding | Scaled (lecun_normal) |
64
+ | Position encoding | RoPE (theta 10000) |
65
+ | Activation | SwiGLU |
66
+ | Normalization | Parameterless Pre-RMSNorm |
67
+ | Attention | Gated (sigmoid output gate) |
68
+ | Training unique tokens | 40 B |
69
+ | Optimizer | AdamATan2 (beta 0.9 / 0.95, wd 0.1, EMA 0.9999) |
70
+ | LR | 2.2e-4 (warmup 2000 steps) |
71
+ | Global batch | 196,608 tokens |
72
+ | dtype | bfloat16 |
73
+
74
+ ## Usage
75
+
76
+ ```python
77
+ from transformers import AutoModelForCausalLM, AutoTokenizer
78
+ import torch
79
+
80
+ model_id = "sapientinc/HRM-Text-1B"
81
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
82
+ model = AutoModelForCausalLM.from_pretrained(
83
+ model_id,
84
+ dtype=torch.bfloat16,
85
+ trust_remote_code=True,
86
+ ).cuda().eval()
87
+
88
+ # synth,cot composite — reasoning / CoT style (see Disclaimer for other modes)
89
+ condition = "<|quad_end|><|object_ref_end|>"
90
+ prompt = f"<|im_start|>{condition}Explain why the sky is blue.<|im_end|>"
91
+
92
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
93
+ # Mark the prompt as a single bidirectional prefix block — see "PrefixLM mask" below.
94
+ inputs["token_type_ids"] = torch.ones_like(inputs["input_ids"])
95
+
96
+ with torch.no_grad():
97
+ out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
98
+ print(tokenizer.decode(out[0], skip_special_tokens=False))
99
+ ```
100
+
101
+ ### PrefixLM mask — pass `token_type_ids`
102
+
103
+ HRM-Text was pre-trained with a PrefixLM mask: prompt tokens attend bidirectionally to each other, response tokens attend causally. To match the training-time forward pass at inference, you must tell the model which positions belong to the prefix.
104
+
105
+ In the current Transformers port the mask is controlled by `token_type_ids`:
106
+ - `token_type_ids[i] == 1` → position `i` is part of the prefix block (bidirectional within the block).
107
+ - otherwise → causal.
108
+
109
+ If you omit `token_type_ids`, attention falls back to **pure causal**, which does **not** match the pre-training distribution and will give noticeably worse logits. The simplest correct call passes `token_type_ids = torch.ones_like(input_ids)`, marking the entire input prompt as one bidirectional prefix block — exactly how training-time prefill ran.
110
+
111
+ ## Architecture
112
+
113
+ The recurrent core (per forward pass, in inference mode):
114
+
115
+ ```
116
+ z_H = embed(input_ids) * embedding_scale
117
+ z_L = z_L_init.expand_as(z_H)
118
+
119
+ for _ in range(H_cycles):
120
+ for _ in range(L_cycles):
121
+ z_L = L_module(z_L + z_H)
122
+ z_H = H_module(z_H + z_L)
123
+ return z_H
124
+ ```
125
+
126
+ Both stacks share the same Transformer block design (gated attention, RoPE, SwiGLU, pre-RMSNorm); see Model details above for shapes.
127
+
128
+ ## Training data
129
+
130
+ Pre-trained on a sampled mixture of publicly available text corpora. The full dataset composition, sampling weights, and preprocessing pipeline are open-sourced:
131
+
132
+ <p align="center">
133
+ <a href="https://github.com/sapientinc/data_io"><img alt="data_io" src="https://img.shields.io/badge/GitHub-sapientinc%2Fdata__io-181717?logo=github&logoColor=white"></a>
134
+ </p>
135
+
136
+ ## Limitations
137
+
138
+ - English only (training corpus is predominantly English).
139
+ - Outputs may be inaccurate, biased, or unsafe.
140
+
141
+ ## License
142
+
143
+ [Apache License 2.0](LICENSE).
144
+
145
+ ## Citation
146
+
147
+ Citation information will be added with the accompanying paper.
__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .configuration_hrm_text import *
15
+ from .modeling_hrm_text import *
banner.jpg ADDED

Git LFS Details

  • SHA256: 4ec52ca7bc19373cbf999451bf79c55c3ff09c3d1fd24a46c2f467b852abf420
  • Pointer size: 131 Bytes
  • Size of remote file: 516 kB
benchmark_scatter.png ADDED

Git LFS Details

  • SHA256: f0eddfa5c28e0069bfd0b2d50a8c9e4b2a7c720144f4ea9abb38105f77e2ba20
  • Pointer size: 131 Bytes
  • Size of remote file: 413 kB
config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "hrm_text",
3
+ "architectures": [
4
+ "HrmTextForCausalLM"
5
+ ],
6
+ "vocab_size": 65536,
7
+ "hidden_size": 1536,
8
+ "intermediate_size": 4096,
9
+ "num_hidden_layers": 16,
10
+ "num_attention_heads": 12,
11
+ "num_key_value_heads": 12,
12
+ "head_dim": 128,
13
+ "H_cycles": 2,
14
+ "L_cycles": 3,
15
+ "L_bp_cycles": [
16
+ 2
17
+ ],
18
+ "max_position_embeddings": 4096,
19
+ "rms_norm_eps": 1e-06,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "initializer_range": 0.025515518153991442,
23
+ "embedding_scale": 39.191835884530846,
24
+ "prefix_lm": true,
25
+ "pad_token_id": 5,
26
+ "bos_token_id": 6,
27
+ "eos_token_id": 11,
28
+ "auto_map": {
29
+ "AutoConfig": "configuration_hrm_text.HrmTextConfig",
30
+ "AutoModel": "modeling_hrm_text.HrmTextModel",
31
+ "AutoModelForCausalLM": "modeling_hrm_text.HrmTextForCausalLM"
32
+ }
33
+ }
configuration_hrm_text.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/hrm_text/modular_hrm_text.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_hrm_text.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from huggingface_hub.dataclasses import strict
22
+
23
+ from transformers.configuration_utils import PreTrainedConfig
24
+ from transformers.modeling_rope_utils import RopeParameters
25
+ from transformers.utils import auto_docstring
26
+ from transformers.utils.generic import is_flash_attention_requested, split_attention_implementation
27
+ from transformers.utils.type_validators import interval
28
+
29
+
30
@auto_docstring(checkpoint="sapientinc/HRM-Text-1B")
@strict
class HrmTextConfig(PreTrainedConfig):
    r"""
    H_cycles (`int`, *optional*, defaults to 2):
        Number of high-level cycles.
    L_cycles (`int`, *optional*, defaults to 3):
        Number of low-level cycles per H-cycle.
    L_bp_cycles (`list[int]`, *optional*, defaults to `[2]`):
        Training-time gradient-routing list; left-padded with `1`s up to `L_cycles` inside the model.
        Inference-time no-op.
    embedding_scale (`float`, *optional*):
        Token-embedding multiplier. If `None`, defaults to `1 / initializer_range`.
    prefix_lm (`bool`, *optional*, defaults to `True`):
        Instruction tokens attend bidirectionally, response tokens attend causally.
    num_layers_per_stack (`int`, *optional*):
        Real number of transformer blocks inside each
        of the H / L stacks. Set automatically on first construction: the value passed as
        `num_hidden_layers` is remembered here and `num_hidden_layers` is then rewritten to
        `num_layers_per_stack * H_cycles * (L_cycles + 1)` so that
        `DynamicCache(config=...)` pre-allocates one slot per unique attention invocation
        under the recurrent forward. Do not set this directly on first construction — pass
        the real per-stack count as `num_hidden_layers` and let `__post_init__` split it.
    """

    model_type = "hrm_text"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Tensor-parallel plan: each projection pattern is registered once per stack
    # (L first, then H), in the same key order as a hand-written literal would give.
    base_model_tp_plan = {
        f"{stack}.layers.*.{projection}": sharding
        for projection, sharding in (
            ("self_attn.q_proj", "colwise"),
            ("self_attn.k_proj", "colwise"),
            ("self_attn.v_proj", "colwise"),
            ("self_attn.gate_proj", "colwise"),
            ("self_attn.o_proj", "rowwise"),
            ("mlp.gate_proj", "colwise"),
            ("mlp.up_proj", "colwise"),
            ("mlp.down_proj", "rowwise"),
        )
        for stack in ("L_module", "H_module")
    }
    # Pipeline-parallel plan: (input names, output names) per top-level submodule.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    # Generic Transformer hyper-parameters.
    vocab_size: int = 151808
    hidden_size: int = 1536
    intermediate_size: int = 4096
    num_hidden_layers: int = 16
    num_attention_heads: int = 12
    hidden_act: str = "silu"
    max_position_embeddings: int = 2048
    initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    pad_token_id: int | None = None
    bos_token_id: int | None = None
    eos_token_id: int | list[int] | None = None
    tie_word_embeddings: bool = False
    rope_parameters: RopeParameters | dict | None = None
    attention_bias: bool = False
    attention_dropout: int | float | None = 0.0
    mlp_bias: bool = False
    head_dim: int = 128

    # HRM-specific hyper-parameters (documented in the class docstring).
    H_cycles: int = 2
    L_cycles: int = 3
    L_bp_cycles: list[int] | None = None
    embedding_scale: float | None = None
    prefix_lm: bool = True
    num_layers_per_stack: int | None = None  # Usually inferred in post init

    def __post_init__(self, **kwargs):
        """Fill in derived defaults, then defer to the parent `__post_init__`."""
        # `[2]` means: backprop only the last 2 L-iterations per H-cycle (a training-time
        # gradient-routing knob). The left-padding to length `L_cycles` is done inside
        # [`HrmTextModel`] because it depends on `L_cycles`.
        if self.L_bp_cycles is None:
            self.L_bp_cycles = [2]

        # Without an explicit scale, use the reciprocal of the initializer range.
        if self.embedding_scale is None:
            self.embedding_scale = 1.0 / self.initializer_range

        # First construction (or a legacy checkpoint whose `num_hidden_layers` still holds
        # the real per-stack count): remember the real count, then inflate
        # `num_hidden_layers` so standard HF cache allocation yields one slot per unique
        # attention invocation. Serialised configs round-trip as (inflated, real) pairs,
        # so this branch is skipped once `num_layers_per_stack` is present.
        if self.num_layers_per_stack is None:
            per_stack = self.num_hidden_layers
            self.num_layers_per_stack = per_stack
            self.num_hidden_layers = per_stack * self.H_cycles * (self.L_cycles + 1)

        super().__post_init__(**kwargs)

    def validate_architecture(self):
        """Part of `@strict`-powered validation. Validates the architecture of the config.

        Raises:
            ValueError: if `hidden_size` is not a multiple of `num_attention_heads`.
        """
        if self.hidden_size % self.num_attention_heads == 0:
            return
        raise ValueError(
            f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
            f"heads ({self.num_attention_heads})."
        )

    @property
    def _attn_implementation(self):
        """Return the attention implementation stored in `_attn_implementation_internal`."""
        return self._attn_implementation_internal

    @_attn_implementation.setter
    def _attn_implementation(self, value: str | dict | None):
        # Nothing to check when no implementation is given or PrefixLM is disabled;
        # hand the value straight to the parent property's setter.
        if value is None or not self.prefix_lm:
            PreTrainedConfig._attn_implementation.__set__(self, value)
            return

        # PrefixLM is on: reject FlashAttention variants before storing the value.
        _unused, requested = split_attention_implementation(value)
        if is_flash_attention_requested(requested_attention_implementation=requested):
            raise ValueError(
                f"`attn_implementation={value!r}` is not supported when "
                "`config.prefix_lm=True`: FlashAttention cannot represent the PrefixLM 4-D mask "
                "overlay. Use `'sdpa'` (default) or `'flex_attention'`, or set `config.prefix_lm=False`."
            )
        PreTrainedConfig._attn_implementation.__set__(self, value)
144
+
145
+
146
+ __all__ = ["HrmTextConfig"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8fe2b2bf6948414e8e8d6538659198726d98f967c55b533b7aabe8a1fa9a584
3
+ size 2365606568
modeling_hrm_text.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/hrm_text/modular_hrm_text.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_hrm_text.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from collections.abc import Callable
22
+ from contextlib import nullcontext
23
+ from typing import Optional
24
+
25
+ import torch
26
+ from torch import nn
27
+
28
+ from transformers import initialization as init
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.configuration_utils import PreTrainedConfig
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.integrations import use_kernel_func_from_hub, use_kernelized_func
34
+ from transformers.masking_utils import create_causal_mask, create_masks_for_generate
35
+ from transformers.modeling_layers import GradientCheckpointingLayer
36
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
37
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
38
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
+ from transformers.processing_utils import Unpack
40
+ from transformers.utils import auto_docstring, can_return_tuple, logging
41
+ from transformers.utils.generic import (
42
+ TransformersKwargs,
43
+ is_flash_attention_requested,
44
+ maybe_autocast,
45
+ merge_with_config_defaults,
46
+ split_attention_implementation,
47
+ )
48
+ from transformers.utils.output_capturing import capture_outputs
49
+ from .configuration_hrm_text import HrmTextConfig
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
class HrmTextRMSNorm(torch.nn.Module):
    """Parameter-free RMS normalisation over the last dimension.

    The input is divided by the root of its mean square (plus `eps`) along the last axis.
    There is no learnable scale, so the module holds no parameters.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        # Added inside the square root for numerical stability.
        self.eps = eps

    def _norm(self, x):
        """Return `x / sqrt(mean(x**2, dim=-1) + eps)`."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalise in float32, then cast back to the caller's dtype.
        return self._norm(x.float()).type_as(x)

    def extra_repr(self):
        return f"eps={self.eps}"
68
+
69
+
70
class HrmTextMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.

    Sizes, bias usage and the activation are read from `config` (`hidden_size`,
    `intermediate_size`, `mlp_bias`, `hidden_act`).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # The activated gate branch modulates the linear "up" branch element-wise.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
84
+
85
+
86
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dimension at its midpoint into `(x1, x2)` and returns `cat(-x2, x1)`.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
91
+
92
+
93
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against
            `q` and `k`. With `cos`/`sin` of shape [batch_size, seq_len, head_dim], use
            `unsqueeze_dim=1` when `q`/`k` are [batch_size, heads, seq_len, head_dim] and
            `unsqueeze_dim=2` when they are [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
117
+
118
+
119
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    When `n_rep == 1` the input tensor itself is returned (no copy).
    """
    # Unpacking up front also enforces that the input is 4-D, even on the n_rep == 1 fast path.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
129
+
130
+
131
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled dot-product attention.

    Args:
        module: Attention module; supplies `num_key_value_groups` and the `training` flag.
        query, key, value: Attention inputs. Key/value heads are repeated
            `module.num_key_value_groups` times before use.
        attention_mask: Optional additive mask, added to the scaled scores.
        scaling: Multiplier applied to the raw `query @ key^T` scores.
        dropout: Dropout probability on the attention weights (active only in training).
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        `(attn_output, attn_weights)`; `attn_output` has dims 1 and 2 swapped relative to
        the `attn_weights @ value` product and is made contiguous.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Softmax in float32 for stability, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
154
+
155
+
156
@use_kernelized_func(apply_rotary_pos_emb)
class HrmTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Key/value projections use `config.num_attention_heads` heads (MHA, no grouping). The
    attention output is multiplied element-wise by a sigmoid gate computed from the input
    before the output projection.
    """

    def __init__(self, config: HrmTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to hidden_size // num_heads when `head_dim` is missing *or* None, matching
        # the fallback used by the rotary embedding so both always agree on the head width.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = 1  # Uses MHA instead of GQA
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        projection_dim = config.num_attention_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, projection_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, projection_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, projection_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(projection_dim, config.hidden_size, bias=config.attention_bias)
        # Additional sigmoid gate applied at the end
        self.gate_proj = nn.Linear(config.hidden_size, projection_dim, bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        cycle_offset: int = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run gated self-attention.

        Args:
            hidden_states: Input; every dimension but the last is preserved in the output.
            position_embeddings: `(cos, sin)` pair for the rotary embedding. Unpacked
                unconditionally, so it must be supplied despite the `None` default.
            attention_mask: Optional mask forwarded to the attention backend.
            past_key_values: Optional cache updated with the new key/value states.
            cycle_offset: Added to `layer_idx` to select the cache slot.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)` as produced by the backend, with `attn_output`
            gated and passed through `o_proj`.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        # The gate is left un-transposed: it is multiplied with the backend output below.
        gate_states = self.gate_proj(hidden_states).view(hidden_shape)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # Adjust cache slot by `cycle_offset` which is determined by its current recurrent step through the stacks
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx + cycle_offset)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Additional sigmoid gating (similar to Qwen3Next)
        attn_output = torch.sigmoid(gate_states) * attn_output
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
236
+
237
+
238
class HrmTextDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer layer: self-attention then MLP, each wrapped in a residual add."""

    def __init__(self, config: HrmTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = HrmTextAttention(config=config, layer_idx=layer_idx)

        self.mlp = HrmTextMLP(config)
        self.input_layernorm = HrmTextRMSNorm(eps=config.rms_norm_eps)
        self.post_attention_layernorm = HrmTextRMSNorm(eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Apply the layer and return the updated hidden states.

        All arguments other than `hidden_states` are forwarded to the attention module;
        extra keyword arguments (for example `cycle_offset`) reach it through `**kwargs`.
        """
        # Self Attention
        attn_residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, _ = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = attn_residual + attn_out

        # Fully Connected
        mlp_residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = mlp_residual + hidden_states
        return hidden_states
279
+
280
+
281
class HrmTextStack(nn.Module):
    """A single transformer stack — used twice inside, once as H module and once as L module

    Holds `config.num_layers_per_stack` decoder layers followed by a final RMS norm.
    """

    def __init__(self, config: HrmTextConfig):
        super().__init__()
        # Layer indices are local to the stack (0..num_layers_per_stack-1); the caller's
        # `cycle_offset` is what separates cache slots across invocations.
        self.layers = nn.ModuleList(
            [HrmTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers_per_stack)]
        )
        self.final_norm = HrmTextRMSNorm(eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        cycle_offset: int = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run every layer in order, then apply the final norm.

        `cycle_offset` and any extra keyword arguments are passed unchanged to each layer.
        """
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_embeddings=position_embeddings,
                cycle_offset=cycle_offset,
                **kwargs,
            )
        return self.final_norm(hidden_states)
310
+
311
+
312
@auto_docstring
class HrmTextPreTrainedModel(PreTrainedModel):
    # Shared base for the HRM-Text models: declares backend-capability flags, rejects
    # FlashAttention when PrefixLM is enabled, and initialises `z_L_init`.
    config: HrmTextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HrmTextDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": HrmTextDecoderLayer,
        "attentions": HrmTextAttention,
    }

    def _check_and_adjust_attn_implementation(
        self, attn_implementation: str | None, is_init_check: bool = False, allow_all_kernels: bool = False
    ) -> str:
        # Same restriction the config setter enforces: a FlashAttention backend cannot be
        # combined with `config.prefix_lm=True`. Everything else defers to the base class.
        if attn_implementation is not None and self.config.prefix_lm:
            _, base_implementation = split_attention_implementation(attn_implementation)
            if is_flash_attention_requested(requested_attention_implementation=base_implementation):
                raise ValueError(
                    f"`attn_implementation={attn_implementation!r}` is not supported when "
                    "`config.prefix_lm=True`: FlashAttention cannot represent the PrefixLM 4-D mask "
                    "overlay. Use `'sdpa'` (default) or `'flex_attention'`, or set `config.prefix_lm=False`."
                )
        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check, allow_all_kernels)

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, HrmTextModel):
            init.zeros_(module.z_L_init)
            # `z_L_init` is the frozen low-cycle initial state and never trains.
            module.z_L_init.requires_grad_(False)  # trf-ignore: TRF012
350
+
351
+
352
class HrmTextRotaryEmbedding(nn.Module):
    """Computes the `(cos, sin)` tables consumed by `apply_rotary_pos_emb`."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: HrmTextConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        # "default" uses the local implementation; any other type comes from the registry.
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: kept out of the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: HrmTextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base**exponents)
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to `x.dtype`.

        `x` is used only for its device and dtype.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" is mapped to "cpu" for the autocast context below.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
415
+
416
+
417
@auto_docstring
class HrmTextModel(HrmTextPreTrainedModel):
    # Backbone: token embedding, rotary embedding and two transformer stacks (`L_module`,
    # `H_module`) applied recurrently for `H_cycles` outer and `L_cycles` inner iterations.
    def __init__(self, config: HrmTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.rotary_emb = HrmTextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.embedding_scale = config.embedding_scale

        # Recursive module structures
        self.L_module = HrmTextStack(config)
        self.H_module = HrmTextStack(config)
        # Initial state for the low cycle module
        self.z_L_init = nn.Parameter(torch.zeros(config.hidden_size), requires_grad=False)

        # Left-pad the configured back-prop counts with 1s.
        # NOTE(review): the list is padded up to `L_cycles` entries but `forward` indexes it
        # by the H-cycle index — confirm whether the pad length should be `H_cycles`.
        raw_bp = list(config.L_bp_cycles)
        self.L_bp_cycles_padded = [1] * max(0, config.L_cycles - len(raw_bp)) + raw_bp

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch, seq_len)`, *optional*):
            Per-position bidirectional/causal indicator. Tokens with `token_type_ids == 1`
            form a single bidirectional block; all other positions are causal.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # Additional scaling on the input embeds
        inputs_embeds = inputs_embeds * self.embedding_scale

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Continue numbering after whatever the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # Create mask with optional prefix-based bidirectionality
        mask_kwargs = {
            "config": self.config,
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
        }
        # The bidirectional prefix overlay is only requested while the cache is still empty.
        is_first_iteration = past_key_values is None or not past_key_values.is_initialized
        if token_type_ids is not None and is_first_iteration:
            if self.config.prefix_lm:
                mask_kwargs["block_sequence_ids"] = torch.where(token_type_ids == 1, 0, -1)
            else:
                logger.warning_once("`token_type_ids` was provided but `config.prefix_lm=False`; ignoring it.")

        attention_mask = create_causal_mask(**mask_kwargs)
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)

        # Hierarchical (H/L)-cycle recurrence
        #
        # `z_H` - slow / high-level state
        hidden_states_high_cycle = inputs_embeds
        # `z_L` - fast / low-level state
        hidden_states_low_cycle = (
            self.z_L_init.to(dtype=hidden_states_high_cycle.dtype, device=hidden_states_high_cycle.device)
            .expand_as(hidden_states_high_cycle)
            .contiguous()
        )

        # Cache-slot layout under the recurrent forward:
        #
        #   slot(h, l, layer) = (h * (L_cycles + 1) + l) * num_layers_per_stack + layer
        #                       ^— L-stack invocation at (h, l)
        #   slot(h, H, layer) = (h * (L_cycles + 1) + L_cycles) * num_layers_per_stack + layer
        #                       ^— trailing H-stack invocation
        #
        # That totals `num_layers_per_stack * H_cycles * (L_cycles + 1)` slots, i.e. the `config.num_hidden_layers`.
        num_layers_per_stack = self.config.num_layers_per_stack
        num_low_cycles = self.config.L_cycles
        invocations_per_high_cycle = num_low_cycles + 1
        for high_cycle_idx in range(self.config.H_cycles):
            # `L_bp_cycles` k-step grad trick: only the trailing `num_grad_iterations` of the
            # `L_cycles` inner iterations propagate gradients; earlier iterations run under
            # `torch.no_grad()` to bound activation memory.
            num_grad_iterations = (
                self.L_bp_cycles_padded[high_cycle_idx] if high_cycle_idx < len(self.L_bp_cycles_padded) else 1
            )
            grad_threshold = num_low_cycles - num_grad_iterations
            first_invocation = high_cycle_idx * invocations_per_high_cycle
            for low_cycle_idx in range(num_low_cycles):
                cycle_offset = (first_invocation + low_cycle_idx) * num_layers_per_stack
                ctx = nullcontext() if low_cycle_idx >= grad_threshold else torch.no_grad()
                with ctx:
                    hidden_states_low_cycle = self.L_module(
                        hidden_states_low_cycle.to(hidden_states_high_cycle.device) + hidden_states_high_cycle,
                        attention_mask=attention_mask,
                        past_key_values=past_key_values,
                        position_embeddings=position_embeddings,
                        position_ids=position_ids,
                        cycle_offset=cycle_offset,
                        **kwargs,
                    )

            # The H stack runs once per outer cycle, in the slot after the last L invocation.
            cycle_offset = (first_invocation + num_low_cycles) * num_layers_per_stack

            hidden_states_high_cycle = self.H_module(
                hidden_states_high_cycle + hidden_states_low_cycle.to(hidden_states_high_cycle.device),
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                cycle_offset=cycle_offset,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states_high_cycle,
            past_key_values=past_key_values,
        )
553
+
554
+
555
@auto_docstring
class HrmTextForCausalLM(HrmTextPreTrainedModel, GenerationMixin):
    # `HrmTextModel` backbone plus a bias-free LM head whose weight is tied to the token
    # embedding (see `_tied_weights_keys`).
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = HrmTextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch, seq_len)`, *optional*):
            Per-position bidirectional/causal indicator. Tokens with `token_type_ids == 1`
            form a single bidirectional block; all other positions are causal.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # An int keeps the last `logits_to_keep` positions (0 keeps all, since `-0 == 0`);
        # a tensor is used directly as an index along the sequence axis.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @staticmethod
    def create_masks_for_generate(
        config: PreTrainedConfig,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None,
        position_ids: torch.Tensor | None,
        token_type_ids: torch.Tensor | None = None,
        is_first_iteration: bool | None = False,
        **kwargs,
    ) -> dict:
        # Generation-time counterpart of the mask setup in `HrmTextModel.forward`: adds the
        # PrefixLM block ids on the first iteration, then defers to the module-level
        # `create_masks_for_generate` (the bare name below resolves to the imported function,
        # not to this static method).
        mask_kwargs = {
            "config": config,
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
        }
        if token_type_ids is not None and is_first_iteration:
            if config.prefix_lm:
                mask_kwargs["block_sequence_ids"] = torch.where(token_type_ids == 1, 0, -1)
            else:
                logger.warning_once("`token_type_ids` was provided but `config.prefix_lm=False`; ignoring it.")

        return create_masks_for_generate(**mask_kwargs)
642
+
643
+
644
+ __all__ = ["HrmTextForCausalLM", "HrmTextModel", "HrmTextPreTrainedModel"]
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|im_start|>",
5
+ "eos_token": "<|box_end|>",
6
+ "is_local": true,
7
+ "local_files_only": false,
8
+ "model_max_length": 1000000000000000019884624838656,
9
+ "pad_token": "<|endoftext|>",
10
+ "tokenizer_class": "Qwen2Tokenizer",
11
+ "unk_token": "<|endoftext|>"
12
+ }