File size: 7,629 Bytes
3dcb68a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | # π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
# This file was automatically generated from src/transformers/models/hrm_text/modular_hrm_text.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_hrm_text.py file directly. One of our CI enforces this.
# π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨π¨
# Copyright 2026 The Sapient AI Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub.dataclasses import strict
from transformers.configuration_utils import PreTrainedConfig
from transformers.modeling_rope_utils import RopeParameters
from transformers.utils import auto_docstring
from transformers.utils.generic import is_flash_attention_requested, split_attention_implementation
from transformers.utils.type_validators import interval
@auto_docstring(checkpoint="sapientinc/HRM-Text-1B")
@strict
class HrmTextConfig(PreTrainedConfig):
    r"""
    H_cycles (`int`, *optional*, defaults to 2):
        Number of high-level cycles.
    L_cycles (`int`, *optional*, defaults to 3):
        Number of low-level cycles per H-cycle.
    L_bp_cycles (`list[int]`, *optional*, defaults to `[2]`):
        Training-time gradient-routing list; left-padded with `1`s up to `L_cycles` inside the model.
        Inference-time no-op.
    embedding_scale (`float`, *optional*):
        Token-embedding multiplier. If `None`, defaults to `1 / initializer_range`, which requires a
        non-zero `initializer_range`.
    prefix_lm (`bool`, *optional*, defaults to `True`):
        Instruction tokens attend bidirectionally, response tokens attend causally.
    num_layers_per_stack (`int`, *optional*):
        Real number of transformer blocks inside each of the H / L stacks. Set automatically on
        first construction: the value passed as `num_hidden_layers` is remembered here and
        `num_hidden_layers` is then rewritten to `num_layers_per_stack * H_cycles * (L_cycles + 1)`
        so that `DynamicCache(config=...)` pre-allocates one slot per unique attention invocation
        under the recurrent forward. Do not set this directly on first construction — pass the real
        per-stack count as `num_hidden_layers` and let `__post_init__` split it.
    """

    model_type = "hrm_text"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Tensor-parallel plan: the same sharding style is declared for both the low-level
    # (`L_module`) and high-level (`H_module`) stacks.
    base_model_tp_plan = {
        **{f"{stack}.layers.*.self_attn.q_proj": "colwise" for stack in ("L_module", "H_module")},
        **{f"{stack}.layers.*.self_attn.k_proj": "colwise" for stack in ("L_module", "H_module")},
        **{f"{stack}.layers.*.self_attn.v_proj": "colwise" for stack in ("L_module", "H_module")},
        **{f"{stack}.layers.*.self_attn.gate_proj": "colwise" for stack in ("L_module", "H_module")},
        **{f"{stack}.layers.*.self_attn.o_proj": "rowwise" for stack in ("L_module", "H_module")},
        **{f"{stack}.layers.*.mlp.gate_proj": "colwise" for stack in ("L_module", "H_module")},
        **{f"{stack}.layers.*.mlp.up_proj": "colwise" for stack in ("L_module", "H_module")},
        **{f"{stack}.layers.*.mlp.down_proj": "rowwise" for stack in ("L_module", "H_module")},
    }
    # Pipeline-parallel plan: module name -> (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    vocab_size: int = 151808
    hidden_size: int = 1536
    intermediate_size: int = 4096
    # On first construction this is the per-stack layer count; `__post_init__` rewrites it to the
    # inflated total (see `num_layers_per_stack`).
    num_hidden_layers: int = 16
    num_attention_heads: int = 12
    hidden_act: str = "silu"
    max_position_embeddings: int = 2048
    initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    pad_token_id: int | None = None
    bos_token_id: int | None = None
    eos_token_id: int | list[int] | None = None
    tie_word_embeddings: bool = False
    rope_parameters: RopeParameters | dict | None = None
    attention_bias: bool = False
    attention_dropout: int | float | None = 0.0
    mlp_bias: bool = False
    head_dim: int = 128
    H_cycles: int = 2
    L_cycles: int = 3
    L_bp_cycles: list[int] | None = None
    embedding_scale: float | None = None
    prefix_lm: bool = True
    num_layers_per_stack: int | None = None  # Usually inferred in post init

    def __post_init__(self, **kwargs):
        """Fill in derived defaults, then defer to the parent `__post_init__`.

        Resolves `L_bp_cycles`, `embedding_scale` and `num_layers_per_stack` when they were left
        as `None`, rewriting `num_hidden_layers` to the inflated total in the latter case.

        Raises:
            ValueError: If `embedding_scale` is `None` and `initializer_range` is `0`, since the
                default `1 / initializer_range` is then undefined.
        """
        if self.L_bp_cycles is None:
            # Default `[2]` = backprop only the last 2 L-iterations per H-cycle (training-time
            # gradient-routing knob). Left-padding to length `L_cycles` is performed inside
            # [`HrmTextModel`] since it depends on `L_cycles`.
            self.L_bp_cycles = [2]
        if self.embedding_scale is None:
            # The field validator on `initializer_range` is declared with `min=0.0`, so a zero
            # value can reach this point; fail with an actionable message instead of a bare
            # `ZeroDivisionError`.
            if self.initializer_range == 0:
                raise ValueError(
                    "`embedding_scale` defaults to `1 / initializer_range`, which is undefined for "
                    "`initializer_range=0`. Pass an explicit `embedding_scale` or use a non-zero "
                    "`initializer_range`."
                )
            self.embedding_scale = 1.0 / self.initializer_range
        if self.num_layers_per_stack is None:
            # Initial construction, or legacy checkpoint where `num_hidden_layers` carries the
            # real per-stack count: remember that value and rewrite `num_hidden_layers` to the
            # inflated total, so standard HF cache allocation gives us one slot per unique
            # attention invocation. Serialised configs round-trip as (inflated, real) pairs.
            self.num_layers_per_stack = self.num_hidden_layers
            self.num_hidden_layers = self.num_layers_per_stack * self.H_cycles * (self.L_cycles + 1)
        super().__post_init__(**kwargs)

    def validate_architecture(self):
        """Part of `@strict`-powered validation. Validates the architecture of the config.

        Raises:
            ValueError: If `hidden_size` is not a multiple of `num_attention_heads`.
        """
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})."
            )

    @property
    def _attn_implementation(self):
        """Return the stored attention implementation (`_attn_implementation_internal`)."""
        return self._attn_implementation_internal

    @_attn_implementation.setter
    def _attn_implementation(self, value: str | dict | None):
        """Set the attention implementation, rejecting FlashAttention when `prefix_lm` is enabled.

        Raises:
            ValueError: If `prefix_lm` is `True` and `value` requests a FlashAttention backend.
        """
        if value is not None and self.prefix_lm:
            # Only the second element returned by the helper (the base implementation) is checked.
            _, base_implementation = split_attention_implementation(value)
            if is_flash_attention_requested(requested_attention_implementation=base_implementation):
                raise ValueError(
                    f"`attn_implementation={value!r}` is not supported when "
                    "`config.prefix_lm=True`: FlashAttention cannot represent the PrefixLM 4-D mask "
                    "overlay. Use `'sdpa'` (default) or `'flex_attention'`, or set `config.prefix_lm=False`."
                )
        # Delegate the actual assignment to the parent class's property setter.
        PreTrainedConfig._attn_implementation.__set__(self, value)
# Public API of this module: only the config class is exported by `from ... import *`.
__all__ = ["HrmTextConfig"]
|