abcd1927 committed on
Commit
9f082d6
·
1 Parent(s): 1f82ac2

Switch to native transformers hrm_text support

Browse files

transformers 5.9.0 ships native HrmTextForCausalLM. Drop the custom
modeling code and trust_remote_code path: bump install requirement to
>=5.9.0, remove auto_map from config, and delete the Python sources.

Files changed (5) hide show
  1. README.md +2 -3
  2. __init__.py +0 -15
  3. config.json +1 -6
  4. configuration_hrm_text.py +0 -146
  5. modeling_hrm_text.py +0 -644
README.md CHANGED
@@ -46,10 +46,10 @@ The four single condition tags and their assigned tokenizer special tokens (toke
46
 
47
  ## Requirements
48
 
49
- Use a Transformers build that includes the `hrm_text` model class. If your installed release does not include it yet, install Transformers directly from the upstream `main` branch:
50
 
51
  ```bash
52
- pip install --upgrade "git+https://github.com/huggingface/transformers.git@main"
53
  ```
54
 
55
  ## Model details
@@ -85,7 +85,6 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
85
  model = AutoModelForCausalLM.from_pretrained(
86
  model_id,
87
  dtype=torch.bfloat16,
88
- trust_remote_code=True,
89
  ).cuda().eval()
90
 
91
  # synth,cot composite — reasoning / CoT style (see Disclaimer for other modes)
 
46
 
47
  ## Requirements
48
 
49
+ Requires `transformers >= 5.9.0`, which ships native support for the `hrm_text` model class:
50
 
51
  ```bash
52
+ pip install --upgrade "transformers>=5.9.0"
53
  ```
54
 
55
  ## Model details
 
85
  model = AutoModelForCausalLM.from_pretrained(
86
  model_id,
87
  dtype=torch.bfloat16,
 
88
  ).cuda().eval()
89
 
90
  # synth,cot composite — reasoning / CoT style (see Disclaimer for other modes)
__init__.py DELETED
@@ -1,15 +0,0 @@
1
- # Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from .configuration_hrm_text import *
15
- from .modeling_hrm_text import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.json CHANGED
@@ -25,10 +25,5 @@
25
  "prefix_lm": true,
26
  "pad_token_id": 5,
27
  "bos_token_id": 6,
28
- "eos_token_id": 11,
29
- "auto_map": {
30
- "AutoConfig": "configuration_hrm_text.HrmTextConfig",
31
- "AutoModel": "modeling_hrm_text.HrmTextModel",
32
- "AutoModelForCausalLM": "modeling_hrm_text.HrmTextForCausalLM"
33
- }
34
  }
 
25
  "prefix_lm": true,
26
  "pad_token_id": 5,
27
  "bos_token_id": 6,
28
+ "eos_token_id": 11
 
 
 
 
 
29
  }
configuration_hrm_text.py DELETED
@@ -1,146 +0,0 @@
1
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
- # This file was automatically generated from src/transformers/models/hrm_text/modular_hrm_text.py.
3
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
- # the file from the modular. If any change should be done, please apply the change to the
5
- # modular_hrm_text.py file directly. One of our CI enforces this.
6
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
- # Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
-
21
- from huggingface_hub.dataclasses import strict
22
-
23
- from transformers.configuration_utils import PreTrainedConfig
24
- from transformers.modeling_rope_utils import RopeParameters
25
- from transformers.utils import auto_docstring
26
- from transformers.utils.generic import is_flash_attention_requested, split_attention_implementation
27
- from transformers.utils.type_validators import interval
28
-
29
-
30
- @auto_docstring(checkpoint="sapientinc/HRM-Text-1B")
31
- @strict
32
- class HrmTextConfig(PreTrainedConfig):
33
- r"""
34
- H_cycles (`int`, *optional*, defaults to 2):
35
- Number of high-level cycles.
36
- L_cycles (`int`, *optional*, defaults to 3):
37
- Number of low-level cycles per H-cycle.
38
- L_bp_cycles (`list[int]`, *optional*, defaults to `[2]`):
39
- Training-time gradient-routing list; left-padded with `1`s up to `H_cycles` inside the model.
40
- Inference-time no-op.
41
- embedding_scale (`float`, *optional*):
42
- Token-embedding multiplier. If `None`, defaults to `1 / initializer_range`.
43
- prefix_lm (`bool`, *optional*, defaults to `True`):
44
- Instruction tokens attend bidirectionally, response tokens attend causally.
45
- num_layers_per_stack (`int`, *optional*):
46
- Real number of transformer blocks inside each
47
- of the H / L stacks. Set automatically on first construction: the value passed as
48
- `num_hidden_layers` is remembered here and `num_hidden_layers` is then rewritten to
49
- `num_layers_per_stack * H_cycles * (L_cycles + 1)` so that
50
- `DynamicCache(config=...)` pre-allocates one slot per unique attention invocation
51
- under the recurrent forward. Do not set this directly on first construction — pass
52
- the real per-stack count as `num_hidden_layers` and let `__post_init__` split it.
53
- """
54
-
55
- model_type = "hrm_text"
56
- keys_to_ignore_at_inference = ["past_key_values"]
57
-
58
- base_model_tp_plan = {
59
- **{f"{stack}.layers.*.self_attn.q_proj": "colwise" for stack in ("L_module", "H_module")},
60
- **{f"{stack}.layers.*.self_attn.k_proj": "colwise" for stack in ("L_module", "H_module")},
61
- **{f"{stack}.layers.*.self_attn.v_proj": "colwise" for stack in ("L_module", "H_module")},
62
- **{f"{stack}.layers.*.self_attn.gate_proj": "colwise" for stack in ("L_module", "H_module")},
63
- **{f"{stack}.layers.*.self_attn.o_proj": "rowwise" for stack in ("L_module", "H_module")},
64
- **{f"{stack}.layers.*.mlp.gate_proj": "colwise" for stack in ("L_module", "H_module")},
65
- **{f"{stack}.layers.*.mlp.up_proj": "colwise" for stack in ("L_module", "H_module")},
66
- **{f"{stack}.layers.*.mlp.down_proj": "rowwise" for stack in ("L_module", "H_module")},
67
- }
68
- base_model_pp_plan = {
69
- "embed_tokens": (["input_ids"], ["inputs_embeds"]),
70
- "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
71
- "norm": (["hidden_states"], ["hidden_states"]),
72
- }
73
-
74
- vocab_size: int = 151808
75
- hidden_size: int = 1536
76
- intermediate_size: int = 4096
77
- num_hidden_layers: int = 16
78
- num_attention_heads: int = 12
79
- hidden_act: str = "silu"
80
- max_position_embeddings: int = 2048
81
- initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
82
- rms_norm_eps: float = 1e-6
83
- use_cache: bool = True
84
- pad_token_id: int | None = None
85
- bos_token_id: int | None = None
86
- eos_token_id: int | list[int] | None = None
87
- tie_word_embeddings: bool = False
88
- rope_parameters: RopeParameters | dict | None = None
89
- attention_bias: bool = False
90
- attention_dropout: int | float | None = 0.0
91
- mlp_bias: bool = False
92
- head_dim: int = 128
93
-
94
- H_cycles: int = 2
95
- L_cycles: int = 3
96
- L_bp_cycles: list[int] | None = None
97
- embedding_scale: float | None = None
98
- prefix_lm: bool = True
99
- num_layers_per_stack: int | None = None # Usually inferred in post init
100
-
101
- def __post_init__(self, **kwargs):
102
- if self.L_bp_cycles is None:
103
- # Default `[2]` = backprop only the last 2 L-iterations per H-cycle (training-time
104
- # gradient-routing knob). Left-padding to length `H_cycles` is performed inside
105
- # [`HrmTextModel`] since it depends on `H_cycles`.
106
- self.L_bp_cycles = [2]
107
-
108
- if self.embedding_scale is None:
109
- self.embedding_scale = 1.0 / self.initializer_range
110
-
111
- if self.num_layers_per_stack is None:
112
- # Initial construction, or legacy checkpoint where `num_hidden_layers` carries the
113
- # real per-stack count: remember that value and rewrite `num_hidden_layers` to the
114
- # inflated total, so standard HF cache allocation gives us one slot per unique
115
- # attention invocation. Serialised configs round-trip as (inflated, real) pairs.
116
- self.num_layers_per_stack = self.num_hidden_layers
117
- self.num_hidden_layers = self.num_layers_per_stack * self.H_cycles * (self.L_cycles + 1)
118
-
119
- super().__post_init__(**kwargs)
120
-
121
- def validate_architecture(self):
122
- """Part of `@strict`-powered validation. Validates the architecture of the config."""
123
- if self.hidden_size % self.num_attention_heads != 0:
124
- raise ValueError(
125
- f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
126
- f"heads ({self.num_attention_heads})."
127
- )
128
-
129
- @property
130
- def _attn_implementation(self):
131
- return self._attn_implementation_internal
132
-
133
- @_attn_implementation.setter
134
- def _attn_implementation(self, value: str | dict | None):
135
- if value is not None and self.prefix_lm:
136
- _, base_implementation = split_attention_implementation(value)
137
- if is_flash_attention_requested(requested_attention_implementation=base_implementation):
138
- raise ValueError(
139
- f"`attn_implementation={value!r}` is not supported when "
140
- "`config.prefix_lm=True`: FlashAttention cannot represent the PrefixLM 4-D mask "
141
- "overlay. Use `'sdpa'` (default) or `'flex_attention'`, or set `config.prefix_lm=False`."
142
- )
143
- PreTrainedConfig._attn_implementation.__set__(self, value)
144
-
145
-
146
- __all__ = ["HrmTextConfig"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_hrm_text.py DELETED
@@ -1,644 +0,0 @@
1
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
- # This file was automatically generated from src/transformers/models/hrm_text/modular_hrm_text.py.
3
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
- # the file from the modular. If any change should be done, please apply the change to the
5
- # modular_hrm_text.py file directly. One of our CI enforces this.
6
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
- # Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
-
21
- from collections.abc import Callable
22
- from contextlib import nullcontext
23
- from typing import Optional
24
-
25
- import torch
26
- from torch import nn
27
-
28
- from transformers import initialization as init
29
- from transformers.activations import ACT2FN
30
- from transformers.cache_utils import Cache, DynamicCache
31
- from transformers.configuration_utils import PreTrainedConfig
32
- from transformers.generation import GenerationMixin
33
- from transformers.integrations import use_kernel_func_from_hub, use_kernelized_func
34
- from transformers.masking_utils import create_causal_mask, create_masks_for_generate
35
- from transformers.modeling_layers import GradientCheckpointingLayer
36
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
37
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
38
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
- from transformers.processing_utils import Unpack
40
- from transformers.utils import auto_docstring, can_return_tuple, logging
41
- from transformers.utils.generic import (
42
- TransformersKwargs,
43
- is_flash_attention_requested,
44
- maybe_autocast,
45
- merge_with_config_defaults,
46
- split_attention_implementation,
47
- )
48
- from transformers.utils.output_capturing import capture_outputs
49
- from .configuration_hrm_text import HrmTextConfig
50
-
51
-
52
- logger = logging.get_logger(__name__)
53
-
54
-
55
- class HrmTextRMSNorm(torch.nn.Module):
56
- def __init__(self, eps: float = 1e-6):
57
- super().__init__()
58
- self.eps = eps
59
-
60
- def _norm(self, x):
61
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
62
-
63
- def forward(self, x):
64
- return self._norm(x.float()).type_as(x)
65
-
66
- def extra_repr(self):
67
- return f"eps={self.eps}"
68
-
69
-
70
- class HrmTextMLP(nn.Module):
71
- def __init__(self, config):
72
- super().__init__()
73
- self.config = config
74
- self.hidden_size = config.hidden_size
75
- self.intermediate_size = config.intermediate_size
76
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
77
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
78
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
79
- self.act_fn = ACT2FN[config.hidden_act]
80
-
81
- def forward(self, x):
82
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
83
- return down_proj
84
-
85
-
86
- def rotate_half(x):
87
- """Rotates half the hidden dims of the input."""
88
- x1 = x[..., : x.shape[-1] // 2]
89
- x2 = x[..., x.shape[-1] // 2 :]
90
- return torch.cat((-x2, x1), dim=-1)
91
-
92
-
93
- @use_kernel_func_from_hub("rotary_pos_emb")
94
- def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
95
- """Applies Rotary Position Embedding to the query and key tensors.
96
-
97
- Args:
98
- q (`torch.Tensor`): The query tensor.
99
- k (`torch.Tensor`): The key tensor.
100
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
101
- sin (`torch.Tensor`): The sine part of the rotary embedding.
102
- unsqueeze_dim (`int`, *optional*, defaults to 1):
103
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
104
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
105
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
106
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
107
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
108
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
109
- Returns:
110
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
111
- """
112
- cos = cos.unsqueeze(unsqueeze_dim)
113
- sin = sin.unsqueeze(unsqueeze_dim)
114
- q_embed = (q * cos) + (rotate_half(q) * sin)
115
- k_embed = (k * cos) + (rotate_half(k) * sin)
116
- return q_embed, k_embed
117
-
118
-
119
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
120
- """
121
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
122
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
123
- """
124
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
125
- if n_rep == 1:
126
- return hidden_states
127
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
128
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
129
-
130
-
131
- def eager_attention_forward(
132
- module: nn.Module,
133
- query: torch.Tensor,
134
- key: torch.Tensor,
135
- value: torch.Tensor,
136
- attention_mask: torch.Tensor | None,
137
- scaling: float,
138
- dropout: float = 0.0,
139
- **kwargs: Unpack[TransformersKwargs],
140
- ):
141
- key_states = repeat_kv(key, module.num_key_value_groups)
142
- value_states = repeat_kv(value, module.num_key_value_groups)
143
-
144
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
145
- if attention_mask is not None:
146
- attn_weights = attn_weights + attention_mask
147
-
148
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
149
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
150
- attn_output = torch.matmul(attn_weights, value_states)
151
- attn_output = attn_output.transpose(1, 2).contiguous()
152
-
153
- return attn_output, attn_weights
154
-
155
-
156
- @use_kernelized_func(apply_rotary_pos_emb)
157
- class HrmTextAttention(nn.Module):
158
- """Multi-headed attention from 'Attention Is All You Need' paper"""
159
-
160
- def __init__(self, config: HrmTextConfig, layer_idx: int):
161
- super().__init__()
162
- self.config = config
163
- self.layer_idx = layer_idx
164
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
165
- self.num_key_value_groups = 1 # Uses MHA instead of GQA
166
- self.scaling = self.head_dim**-0.5
167
- self.attention_dropout = config.attention_dropout
168
- self.is_causal = True
169
-
170
- self.q_proj = nn.Linear(
171
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
172
- )
173
- self.k_proj = nn.Linear(
174
- config.hidden_size,
175
- config.num_attention_heads * self.head_dim,
176
- bias=config.attention_bias,
177
- )
178
- self.v_proj = nn.Linear(
179
- config.hidden_size,
180
- config.num_attention_heads * self.head_dim,
181
- bias=config.attention_bias,
182
- )
183
- self.o_proj = nn.Linear(
184
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
185
- )
186
- # Additional sigmoid gate applied at the end
187
- self.gate_proj = nn.Linear(
188
- config.hidden_size,
189
- config.num_attention_heads * self.head_dim,
190
- bias=config.attention_bias,
191
- )
192
-
193
- def forward(
194
- self,
195
- hidden_states: torch.Tensor,
196
- position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
197
- attention_mask: torch.Tensor | None = None,
198
- past_key_values: Cache | None = None,
199
- cycle_offset: int = 0,
200
- **kwargs: Unpack[TransformersKwargs],
201
- ) -> tuple[torch.Tensor, torch.Tensor]:
202
- input_shape = hidden_states.shape[:-1]
203
- hidden_shape = (*input_shape, -1, self.head_dim)
204
-
205
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
206
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
207
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
208
- gate_states = self.gate_proj(hidden_states).view(hidden_shape)
209
-
210
- cos, sin = position_embeddings
211
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
212
-
213
- if past_key_values is not None:
214
- # Adjust the cache slot by `cycle_offset`, which is determined by the current recurrent step through the stacks
215
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx + cycle_offset)
216
-
217
- attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
218
- self.config._attn_implementation, eager_attention_forward
219
- )
220
- attn_output, attn_weights = attention_interface(
221
- self,
222
- query_states,
223
- key_states,
224
- value_states,
225
- attention_mask,
226
- dropout=0.0 if not self.training else self.attention_dropout,
227
- scaling=self.scaling,
228
- **kwargs,
229
- )
230
-
231
- # Additional sigmoid gating (similar to Qwen3Next)
232
- attn_output = torch.sigmoid(gate_states) * attn_output
233
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
234
- attn_output = self.o_proj(attn_output)
235
- return attn_output, attn_weights
236
-
237
-
238
- class HrmTextDecoderLayer(GradientCheckpointingLayer):
239
- def __init__(self, config: HrmTextConfig, layer_idx: int):
240
- super().__init__()
241
- self.hidden_size = config.hidden_size
242
-
243
- self.self_attn = HrmTextAttention(config=config, layer_idx=layer_idx)
244
-
245
- self.mlp = HrmTextMLP(config)
246
- self.input_layernorm = HrmTextRMSNorm(eps=config.rms_norm_eps)
247
- self.post_attention_layernorm = HrmTextRMSNorm(eps=config.rms_norm_eps)
248
-
249
- def forward(
250
- self,
251
- hidden_states: torch.Tensor,
252
- attention_mask: torch.Tensor | None = None,
253
- position_ids: torch.LongTensor | None = None,
254
- past_key_values: Cache | None = None,
255
- use_cache: bool | None = False,
256
- position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
257
- **kwargs: Unpack[TransformersKwargs],
258
- ) -> torch.Tensor:
259
- residual = hidden_states
260
- hidden_states = self.input_layernorm(hidden_states)
261
- # Self Attention
262
- hidden_states, _ = self.self_attn(
263
- hidden_states=hidden_states,
264
- attention_mask=attention_mask,
265
- position_ids=position_ids,
266
- past_key_values=past_key_values,
267
- use_cache=use_cache,
268
- position_embeddings=position_embeddings,
269
- **kwargs,
270
- )
271
- hidden_states = residual + hidden_states
272
-
273
- # Fully Connected
274
- residual = hidden_states
275
- hidden_states = self.post_attention_layernorm(hidden_states)
276
- hidden_states = self.mlp(hidden_states)
277
- hidden_states = residual + hidden_states
278
- return hidden_states
279
-
280
-
281
- class HrmTextStack(nn.Module):
282
- """A single transformer stack — used twice inside, once as H module and once as L module"""
283
-
284
- def __init__(self, config: HrmTextConfig):
285
- super().__init__()
286
- self.layers = nn.ModuleList(
287
- [HrmTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers_per_stack)]
288
- )
289
- self.final_norm = HrmTextRMSNorm(eps=config.rms_norm_eps)
290
-
291
- def forward(
292
- self,
293
- hidden_states: torch.Tensor,
294
- attention_mask: torch.Tensor | None = None,
295
- past_key_values: Cache | None = None,
296
- position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
297
- cycle_offset: int = 0,
298
- **kwargs: Unpack[TransformersKwargs],
299
- ) -> torch.Tensor:
300
- for layer in self.layers:
301
- hidden_states = layer(
302
- hidden_states,
303
- attention_mask=attention_mask,
304
- past_key_values=past_key_values,
305
- position_embeddings=position_embeddings,
306
- cycle_offset=cycle_offset,
307
- **kwargs,
308
- )
309
- return self.final_norm(hidden_states)
310
-
311
-
312
- @auto_docstring
313
- class HrmTextPreTrainedModel(PreTrainedModel):
314
- config: HrmTextConfig
315
- base_model_prefix = "model"
316
- supports_gradient_checkpointing = True
317
- _no_split_modules = ["HrmTextDecoderLayer"]
318
- _skip_keys_device_placement = ["past_key_values"]
319
- _supports_flash_attn = True
320
- _supports_sdpa = True
321
- _supports_flex_attn = True
322
-
323
- _can_compile_fullgraph = True
324
- _supports_attention_backend = True
325
- _can_record_outputs = {
326
- "hidden_states": HrmTextDecoderLayer,
327
- "attentions": HrmTextAttention,
328
- }
329
-
330
- def _check_and_adjust_attn_implementation(
331
- self, attn_implementation: str | None, is_init_check: bool = False, allow_all_kernels: bool = False
332
- ) -> str:
333
- if attn_implementation is not None and self.config.prefix_lm:
334
- _, base_implementation = split_attention_implementation(attn_implementation)
335
- if is_flash_attention_requested(requested_attention_implementation=base_implementation):
336
- raise ValueError(
337
- f"`attn_implementation={attn_implementation!r}` is not supported when "
338
- "`config.prefix_lm=True`: FlashAttention cannot represent the PrefixLM 4-D mask "
339
- "overlay. Use `'sdpa'` (default) or `'flex_attention'`, or set `config.prefix_lm=False`."
340
- )
341
- return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check, allow_all_kernels)
342
-
343
- @torch.no_grad()
344
- def _init_weights(self, module):
345
- super()._init_weights(module)
346
- if isinstance(module, HrmTextModel):
347
- init.zeros_(module.z_L_init)
348
- # `z_L_init` is the frozen low-cycle initial state and never trains.
349
- module.z_L_init.requires_grad_(False) # trf-ignore: TRF012
350
-
351
-
352
- class HrmTextRotaryEmbedding(nn.Module):
353
- inv_freq: torch.Tensor # fix linting for `register_buffer`
354
-
355
- def __init__(self, config: HrmTextConfig, device=None):
356
- super().__init__()
357
- self.max_seq_len_cached = config.max_position_embeddings
358
- self.original_max_seq_len = config.max_position_embeddings
359
-
360
- self.config = config
361
-
362
- self.rope_type = self.config.rope_parameters["rope_type"]
363
- rope_init_fn: Callable = self.compute_default_rope_parameters
364
- if self.rope_type != "default":
365
- rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
366
- inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
367
-
368
- self.register_buffer("inv_freq", inv_freq, persistent=False)
369
- self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
370
-
371
- @staticmethod
372
- def compute_default_rope_parameters(
373
- config: HrmTextConfig | None = None,
374
- device: Optional["torch.device"] = None,
375
- seq_len: int | None = None,
376
- ) -> tuple["torch.Tensor", float]:
377
- """
378
- Computes the inverse frequencies according to the original RoPE implementation
379
- Args:
380
- config ([`~transformers.PreTrainedConfig`]):
381
- The model configuration.
382
- device (`torch.device`):
383
- The device to use for initialization of the inverse frequencies.
384
- seq_len (`int`, *optional*):
385
- The current sequence length. Unused for this type of RoPE.
386
- Returns:
387
- Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
388
- post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
389
- """
390
- base = config.rope_parameters["rope_theta"]
391
- dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
392
-
393
- attention_factor = 1.0 # Unused in this type of RoPE
394
-
395
- # Compute the inverse frequencies
396
- inv_freq = 1.0 / (
397
- base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
398
- )
399
- return inv_freq, attention_factor
400
-
401
- @torch.no_grad()
402
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
403
- def forward(self, x, position_ids):
404
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
405
- position_ids_expanded = position_ids[:, None, :].float()
406
-
407
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
408
- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
409
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
410
- emb = torch.cat((freqs, freqs), dim=-1)
411
- cos = emb.cos() * self.attention_scaling
412
- sin = emb.sin() * self.attention_scaling
413
-
414
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
415
-
416
-
417
- @auto_docstring
418
- class HrmTextModel(HrmTextPreTrainedModel):
419
- def __init__(self, config: HrmTextConfig):
420
- super().__init__(config)
421
- self.padding_idx = config.pad_token_id
422
- self.vocab_size = config.vocab_size
423
-
424
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
425
- self.rotary_emb = HrmTextRotaryEmbedding(config=config)
426
- self.gradient_checkpointing = False
427
-
428
- self.embedding_scale = config.embedding_scale
429
-
430
- # Recursive module structures
431
- self.L_module = HrmTextStack(config)
432
- self.H_module = HrmTextStack(config)
433
- # Initial state for the low cycle module
434
- self.z_L_init = nn.Parameter(torch.zeros(config.hidden_size), requires_grad=False)
435
-
436
- raw_bp = list(config.L_bp_cycles)
437
- self.L_bp_cycles_padded = [1] * max(0, config.H_cycles - len(raw_bp)) + raw_bp
438
-
439
- # Initialize weights and apply final processing
440
- self.post_init()
441
-
442
- @merge_with_config_defaults
443
- @capture_outputs
444
- @auto_docstring
445
- def forward(
446
- self,
447
- input_ids: torch.LongTensor | None = None,
448
- attention_mask: torch.Tensor | None = None,
449
- position_ids: torch.LongTensor | None = None,
450
- past_key_values: Cache | None = None,
451
- token_type_ids: torch.LongTensor | None = None,
452
- inputs_embeds: torch.FloatTensor | None = None,
453
- use_cache: bool | None = None,
454
- **kwargs: Unpack[TransformersKwargs],
455
- ) -> BaseModelOutputWithPast:
456
- r"""
457
- token_type_ids (`torch.LongTensor` of shape `(batch, seq_len)`, *optional*):
458
- Per-position bidirectional/causal indicator. Tokens with `token_type_ids == 1`
459
- form a single bidirectional block; all other positions are causal.
460
- """
461
- if (input_ids is None) ^ (inputs_embeds is not None):
462
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
463
-
464
- if inputs_embeds is None:
465
- inputs_embeds = self.embed_tokens(input_ids)
466
- # Additional scaling on the input embeds
467
- inputs_embeds = inputs_embeds * self.embedding_scale
468
-
469
- if use_cache and past_key_values is None:
470
- past_key_values = DynamicCache(config=self.config)
471
-
472
- if position_ids is None:
473
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
474
- position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
475
- position_ids = position_ids.unsqueeze(0)
476
-
477
- # Create mask with optional prefix-based bidirectionality
478
- mask_kwargs = {
479
- "config": self.config,
480
- "inputs_embeds": inputs_embeds,
481
- "attention_mask": attention_mask,
482
- "past_key_values": past_key_values,
483
- "position_ids": position_ids,
484
- }
485
- is_first_iteration = past_key_values is None or not past_key_values.is_initialized
486
- if token_type_ids is not None and is_first_iteration:
487
- if self.config.prefix_lm:
488
- mask_kwargs["block_sequence_ids"] = torch.where(token_type_ids == 1, 0, -1)
489
- else:
490
- logger.warning_once("`token_type_ids` was provided but `config.prefix_lm=False`; ignoring it.")
491
-
492
- attention_mask = create_causal_mask(**mask_kwargs)
493
- position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
494
-
495
- # Hierarchical (H/L)-cycle recurrence
496
- #
497
- # `z_H` - slow / high-level state
498
- hidden_states_high_cycle = inputs_embeds
499
- # `z_L` - fast / low-level state
500
- hidden_states_low_cycle = (
501
- self.z_L_init.to(dtype=hidden_states_high_cycle.dtype, device=hidden_states_high_cycle.device)
502
- .expand_as(hidden_states_high_cycle)
503
- .contiguous()
504
- )
505
-
506
- # Cache-slot layout under the recurrent forward:
507
- #
508
- # slot(h, l, layer) = (h * (L_cycles + 1) + l) * num_layers_per_stack + layer
509
- # ^— L-stack invocation at (h, l)
510
- # slot(h, H, layer) = (h * (L_cycles + 1) + L_cycles) * num_layers_per_stack + layer
511
- # ^— trailing H-stack invocation
512
- #
513
- # That totals `num_layers_per_stack * H_cycles * (L_cycles + 1)` slots, i.e. the `config.num_hidden_layers`.
514
- num_layers_per_stack = self.config.num_layers_per_stack
515
- for high_cycle_idx in range(self.config.H_cycles):
516
- # `L_bp_cycles` k-step grad trick: only the trailing `num_grad_iterations` of the
517
- # `L_cycles` inner iterations propagate gradients; earlier iterations run under
518
- # `torch.no_grad()` to bound activation memory.
519
- num_grad_iterations = (
520
- self.L_bp_cycles_padded[high_cycle_idx] if high_cycle_idx < len(self.L_bp_cycles_padded) else 1
521
- )
522
- grad_threshold = self.config.L_cycles - num_grad_iterations
523
- for low_cycle_idx in range(self.config.L_cycles):
524
- cycle_offset = (high_cycle_idx * (self.config.L_cycles + 1) + low_cycle_idx) * num_layers_per_stack
525
- ctx = nullcontext() if low_cycle_idx >= grad_threshold else torch.no_grad()
526
- with ctx:
527
- hidden_states_low_cycle = self.L_module(
528
- hidden_states_low_cycle.to(hidden_states_high_cycle.device) + hidden_states_high_cycle,
529
- attention_mask=attention_mask,
530
- past_key_values=past_key_values,
531
- position_embeddings=position_embeddings,
532
- position_ids=position_ids,
533
- cycle_offset=cycle_offset,
534
- **kwargs,
535
- )
536
-
537
- cycle_offset = (high_cycle_idx * (self.config.L_cycles + 1) + self.config.L_cycles) * num_layers_per_stack
538
-
539
- hidden_states_high_cycle = self.H_module(
540
- hidden_states_high_cycle + hidden_states_low_cycle.to(hidden_states_high_cycle.device),
541
- attention_mask=attention_mask,
542
- past_key_values=past_key_values,
543
- position_embeddings=position_embeddings,
544
- position_ids=position_ids,
545
- cycle_offset=cycle_offset,
546
- **kwargs,
547
- )
548
-
549
- return BaseModelOutputWithPast(
550
- last_hidden_state=hidden_states_high_cycle,
551
- past_key_values=past_key_values,
552
- )
553
-
554
-
555
@auto_docstring
class HrmTextForCausalLM(HrmTextPreTrainedModel, GenerationMixin):
    # Causal-LM wrapper: an `HrmTextModel` backbone followed by a bias-free linear
    # projection from `hidden_size` to `vocab_size`.
    #
    # NOTE(review): documentation here is written as comments rather than docstrings
    # because the class and `forward` are wrapped by `@auto_docstring`, which
    # presumably consumes docstrings -- confirm before converting these to docstrings.

    # `lm_head.weight` is declared as tied to the backbone's input embedding weight.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    # Presumably the tensor-parallel / pipeline-parallel plans for the head -- confirm
    # the exact semantics against the base class.
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        # Submodules are created in a fixed order: backbone first, then the head.
        self.model = HrmTextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch, seq_len)`, *optional*):
            Per-position bidirectional/causal indicator. Tokens with `token_type_ids == 1`
            form a single bidirectional block; all other positions are causal.
        """
        # Run the backbone; every argument except `labels` and `logits_to_keep`
        # is forwarded unchanged.
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        last_hidden = backbone_out.last_hidden_state

        # An int keeps the trailing `logits_to_keep` positions; `0` yields
        # `slice(0, None)`, i.e. every position. Any non-int value is used
        # directly as the index along the sequence dimension.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(last_hidden[:, keep, :])

        # The loss is only computed when labels are supplied.
        loss = (
            self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )

    @staticmethod
    def create_masks_for_generate(
        config: PreTrainedConfig,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None,
        position_ids: torch.Tensor | None,
        token_type_ids: torch.Tensor | None = None,
        is_first_iteration: bool | None = False,
        **kwargs,
    ) -> dict:
        """Assemble mask arguments and delegate to the module-level `create_masks_for_generate`.

        When `token_type_ids` is given on the first iteration and `config.prefix_lm`
        is set, a `block_sequence_ids` tensor is added: positions whose token type
        is 1 get block id 0 and every other position gets -1. If `config.prefix_lm`
        is falsy the token types are ignored and a one-time warning is logged.
        Extra `**kwargs` are accepted but not forwarded.
        """
        mask_kwargs = dict(
            config=config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        prefix_requested = token_type_ids is not None and is_first_iteration
        if prefix_requested and config.prefix_lm:
            mask_kwargs["block_sequence_ids"] = torch.where(token_type_ids == 1, 0, -1)
        elif prefix_requested:
            logger.warning_once("`token_type_ids` was provided but `config.prefix_lm=False`; ignoring it.")

        # A bare name inside a method body is looked up at module scope (class
        # attributes are not visible here), so this calls the module-level function
        # of the same name rather than recursing into this static method.
        return create_masks_for_generate(**mask_kwargs)
642
-
643
-
644
# Names exported by `from <module> import *`.
__all__ = [
    "HrmTextForCausalLM",
    "HrmTextModel",
    "HrmTextPreTrainedModel",
]