import math
from typing import Callable, Optional

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.integrations import use_kernel_forward_from_hub
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, logging
from transformers.utils.generic import check_model_inputs

from .configuration_hinvec import HinvecConfig


logger = logging.get_logger(__name__)

SUPPORTED_SLIDING_BACKENDS = ["flash_attention_3", "flash_attention_2"]


class HinvecMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

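# Shape note (illustrative, not tied to any particular checkpoint): with q and k of shape
# (batch, num_heads, seq_len, head_dim) and cos/sin of shape (batch, seq_len, head_dim),
# `unsqueeze_dim=1` inserts the head axis so cos/sin broadcast over heads, e.g.
#   q:   (2, 8, 128, 64)
#   cos: (2, 128, 64) -> cos.unsqueeze(1): (2, 1, 128, 64)
# The rotation mixes each feature with the one half a head_dim away (see `rotate_half`).
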
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Standard eager attention without sliding window optimization."""
    # Queries, keys and values already have the same number of heads here, so no KV repetition is needed.
    key_states = repeat_kv(key, 1)
    value_states = repeat_kv(value, 1)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # `attention_mask` is expected to be an additive float mask already broadcastable to
        # (batch, num_heads, query_len, key_len); the caller builds it in that shape.
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights

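# Shape note (illustrative): `query`, `key` and `value` enter the attention functions as
# (batch, num_heads, seq_len, head_dim); the returned `attn_output` is transposed to
# (batch, seq_len, num_heads, head_dim) so the caller can reshape it back to
# (batch, seq_len, hidden_size) before the output projection.
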
class HinvecAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False
        self.embed_dim = config.hidden_size

        self.query = nn.Linear(config.hidden_size, self.embed_dim)
        self.key = nn.Linear(config.hidden_size, self.embed_dim)
        self.value = nn.Linear(config.hidden_size, self.embed_dim)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
        self.layer_idx = layer_idx

        # Initialize the sliding window (only set for layers configured as "sliding_attention")
        self.sliding_window = None
        if config.layer_types[layer_idx] == "sliding_attention":
            self.sliding_window = config.sliding_window
            assert self.sliding_window % 2 == 0, (
                f"`sliding_window` for layer {self.layer_idx} has to be an even value. Given {self.sliding_window}"
            )
            assert self.sliding_window > 0, (
                f"`sliding_window` for layer {self.layer_idx} has to be positive. Given {self.sliding_window}"
            )
            self.one_sided_attn_window_size = self.sliding_window // 2

        # Store attention_dropout as an instance variable
        self.attention_dropout = config.attention_dropout

    def forward(
        self,
        hidden_states,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask=None,
        position_ids: Optional[torch.LongTensor] = None,
        is_index_masked=None,
        output_attentions=False,
        **kwargs,
    ):
        """
        [`HinvecAttention`] expects *len(hidden_states)* to be a multiple of *sliding_window*. Padding to
        *sliding_window* happens in [`HinvecModel.forward`] to avoid redoing the padding on each layer.

        The *attention_mask* is passed from [`HinvecModel.forward`] as an additive float mask:

            - large negative value (`torch.finfo(dtype).min`): no attention (padding tokens)
            - 0: local attention
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # project hidden states
        query_vectors = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_vectors = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_vectors = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        batch_size, seq_len, embed_dim = hidden_states.size()
        assert embed_dim == self.embed_dim, (
            f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
        )

        cos, sin = position_embeddings
        query_vectors, key_vectors = apply_rotary_pos_emb(query_vectors, key_vectors, cos, sin)

        if self.sliding_window and self.config._attn_implementation == "eager":
            # normalize query
            query_vectors /= math.sqrt(self.head_dim)

            query_vectors = query_vectors.transpose(1, 2)
            key_vectors = key_vectors.transpose(1, 2)

            attn_scores = self._sliding_chunks_query_key_matmul(
                query_vectors, key_vectors, self.one_sided_attn_window_size
            )

            # values to pad for attention probs
            remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]

            # cast to fp32/fp16 then replace 1's with -inf
            float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
                remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
            )
            # diagonal mask with zeros everywhere and -inf in place of padding
            diagonal_mask = self._sliding_chunks_query_key_matmul(
                float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
            )

            # pad local attention probs
            attn_scores += diagonal_mask

            assert list(attn_scores.size()) == [
                batch_size,
                seq_len,
                self.num_heads,
                self.one_sided_attn_window_size * 2 + 1,
            ], (
                f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
                f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
            )

            attn_probs = nn.functional.softmax(
                attn_scores, dim=-1, dtype=torch.float32
            )  # use fp32 for numerical stability

            # softmax sometimes inserts NaN if all positions are masked, replace them with 0
            attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
            attn_probs = attn_probs.type_as(attn_scores)

            # free memory
            del attn_scores

            # apply dropout
            attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)

            value_vectors = value_vectors.transpose(1, 2)

            # compute local attn only
            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                attn_probs, value_vectors, self.one_sided_attn_window_size
            )

            assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
            # keep the (batch, seq_len, embed_dim) layout expected by the final reshape below
            attn_output = attn_output.reshape(batch_size, seq_len, embed_dim).contiguous()

            attn_weights = attn_probs.transpose(1, 2)

        elif self.sliding_window:
            if self.config._attn_implementation in SUPPORTED_SLIDING_BACKENDS:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
                attn_output, attn_weights = attention_interface(
                    self,
                    query_vectors,
                    key_vectors,
                    value_vectors,
                    (attention_mask < 0) * torch.finfo(query_vectors.dtype).min,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    scaling=self.scaling,
                    sliding_window=self.sliding_window,
                    **kwargs,
                )
            else:
                # Construct the sliding window mask explicitly for backends without native sliding support
                seq_len = query_vectors.size(-2)
                sw_mask = self._build_sliding_window_mask(seq_len, query_vectors.device, query_vectors.dtype)
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
                attn_output, attn_weights = attention_interface(
                    self,
                    query_vectors,
                    key_vectors,
                    value_vectors,
                    # Combine the padding mask (over the key dimension) with the sliding window mask by addition
                    (attention_mask < 0)[:, None, None, :] * torch.finfo(query_vectors.dtype).min + sw_mask,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    scaling=self.scaling,
                    is_causal=False,  # IMPORTANT: this is a bidirectional encoder
                    **kwargs,
                )
        else:
            # Use the standard (full) attention implementation
            attention_interface: Callable = eager_attention_forward
            if self.config._attn_implementation != "eager":
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

            # Build an additive float mask over the key dimension, broadcastable to
            # (batch, num_heads, query_len, key_len)
            attn = (attention_mask < 0) * torch.finfo(query_vectors.dtype).min
            attn = attn[:, None, None, :].expand(-1, 1, attn.size(1), attn.size(1))
            attn_output, attn_weights = attention_interface(
                self,
                query_vectors,
                key_vectors,
                value_vectors,
                attn,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
    def _build_sliding_window_mask(self, seq_len, device, dtype):
        """
        Creates an additive attention mask of shape (1, 1, seq_len, seq_len) where each token attends only to a
        window around itself. Allowed positions: i-w ... i+w. All other positions are masked with
        `torch.finfo(dtype).min` (effectively -inf).
        """
        w = self.one_sided_attn_window_size
        arange = torch.arange(seq_len, device=device)
        # Boolean mask: True means "disallowed" (outside the window |i - j| <= w)
        bool_mask = (arange[None, :] - arange[:, None]).abs() > w
        # Convert the boolean mask into an additive logit mask: True -> min value, False -> 0
        mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
        mask = mask.masked_fill(bool_mask, torch.finfo(dtype).min)
        # Shape to a (batch, heads, query, key) broadcastable format
        return mask.view(1, 1, seq_len, seq_len)
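    # Illustrative layout (hypothetical sizes): with seq_len = 5 and one-sided window w = 1, the boolean
    # "disallowed" mask built above is
    #   [[False, False, True,  True,  True ],
    #    [False, False, False, True,  True ],
    #    [True,  False, False, False, True ],
    #    [True,  True,  False, False, False],
    #    [True,  True,  True,  False, False]]
    # i.e. token i may attend to tokens i-1, i, i+1; the True entries become torch.finfo(dtype).min.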
    @staticmethod
    def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
        """pads rows and then flips rows and columns"""
        hidden_states_padded = nn.functional.pad(
            hidden_states_padded, padding
        )  # padding value is not important because it will be overwritten
        hidden_states_padded = hidden_states_padded.view(
            *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2)
        )
        return hidden_states_padded

    @staticmethod
    def _pad_and_diagonalize(chunked_hidden_states):
        """
        shift every row 1 step right, converting columns into diagonals.

        Example:

        ```python
        chunked_hidden_states: [
            0.4983, 2.6918, -0.0071, 1.0492,
            -1.8348, 0.7672, 0.2986, 0.0285,
            -0.7584, 0.4206, -0.0405, 0.1599,
            2.0514, -1.1600, 0.5372, 0.2629,
        ]
        window_overlap = num_rows = 4
        ```

        (pad & diagonalize) =>
        [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
          0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000
          0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
          0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
        """
        total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
        chunked_hidden_states = nn.functional.pad(
            chunked_hidden_states, (0, window_overlap + 1)
        )  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap + 1).
        # Padding value is not important because it'll be overwritten
        chunked_hidden_states = chunked_hidden_states.view(
            total_num_heads, num_chunks, -1
        )  # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap
        chunked_hidden_states = chunked_hidden_states[
            :, :, :-window_overlap
        ]  # total_num_heads x num_chunks x window_overlap*window_overlap
        chunked_hidden_states = chunked_hidden_states.view(
            total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
        )
        chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
        return chunked_hidden_states

    @staticmethod
    def _chunk(hidden_states, window_overlap, onnx_export: bool = False):
        """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
        if not onnx_export:
            # non-overlapping chunks of size = 2w
            hidden_states = hidden_states.view(
                hidden_states.size(0),
                torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"),
                window_overlap * 2,
                hidden_states.size(2),
            )
            # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
            chunk_size = list(hidden_states.size())
            chunk_size[1] = chunk_size[1] * 2 - 1

            chunk_stride = list(hidden_states.stride())
            chunk_stride[1] = chunk_stride[1] // 2
            return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)

        # When exporting to ONNX, use this separate logic
        # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export

        # TODO replace this with
        # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
        # once `unfold` is supported
        # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow

        chunk_size = [
            hidden_states.size(0),
            torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
            window_overlap * 2,
            hidden_states.size(2),
        ]

        overlapping_chunks = torch.empty(chunk_size, device=hidden_states.device)
        for chunk in range(chunk_size[1]):
            overlapping_chunks[:, chunk, :, :] = hidden_states[
                :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
            ]
        return overlapping_chunks
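    # Illustrative chunking (hypothetical sizes): with seq_len = 8 and window_overlap w = 2, the sequence is
    # first viewed as 2 non-overlapping chunks of length 2w = 4, then `as_strided` re-reads it as
    # 2 * 2 - 1 = 3 overlapping chunks covering positions [0:4], [2:6] and [4:8], each shifted by w.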
    @staticmethod
    def _mask_invalid_locations(input_tensor, affected_seq_len) -> None:
        beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
        beginning_mask = beginning_mask_2d[None, :, None, :]
        ending_mask = beginning_mask.flip(dims=(1, 3))
        beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
        beginning_mask = beginning_mask.expand(beginning_input.size())
        input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like(
            beginning_input, -float("inf")
        ).where(beginning_mask.bool(), beginning_input)
        ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
        ending_mask = ending_mask.expand(ending_input.size())
        input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like(
            ending_input, -float("inf")
        ).where(ending_mask.bool(), ending_input)

    def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
        """
        Matrix multiplication of query and key tensors using a sliding window attention pattern. This implementation
        splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an overlap of
        size window_overlap
        """
        batch_size, seq_len, num_heads, head_dim = query.size()
        assert seq_len % (window_overlap * 2) == 0, (
            f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
        )
        assert query.size() == key.size()

        chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1

        # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
        query = query.reshape(batch_size * num_heads, seq_len, head_dim)
        key = key.reshape(batch_size * num_heads, seq_len, head_dim)

        query = self._chunk(query, window_overlap, getattr(self.config, "onnx_export", False))
        key = self._chunk(key, window_overlap, getattr(self.config, "onnx_export", False))

        # matrix multiplication
        # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
        diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key))  # multiply

        # convert diagonals into columns
        diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
            diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
        )

        # allocate space for the overall attention matrix where the chunks are combined. The last dimension
        # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower
        # triangles (attention from a word to window_overlap previous words). The following column is attention
        # score from each word to itself, then followed by window_overlap columns for the upper triangle.
        diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
            (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
        )

        # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
        # - copying the main diagonal and the upper triangle
        diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, :, :window_overlap, : window_overlap + 1
        ]
        diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, -1, window_overlap:, : window_overlap + 1
        ]
        # - copying the lower triangle
        diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
            :, :, -(window_overlap + 1) : -1, window_overlap + 1 :
        ]

        diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
            :, 0, : window_overlap - 1, 1 - window_overlap :
        ]

        # separate batch_size and num_heads dimensions again
        diagonal_attention_scores = diagonal_attention_scores.view(
            batch_size, num_heads, seq_len, 2 * window_overlap + 1
        ).transpose(2, 1)

        self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
        return diagonal_attention_scores
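    # Illustrative output layout (hypothetical w): for window_overlap = 2, each query position ends up with
    # 2 * 2 + 1 = 5 scores laid out as [prev-2, prev-1, self, next+1, next+2]; positions that would fall
    # outside the sequence are filled with -inf by `_mask_invalid_locations`.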
    def _sliding_chunks_matmul_attn_probs_value(
        self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int
    ):
        """
        Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. The returned tensor has the
        same shape as `value`.
        """
        batch_size, seq_len, num_heads, head_dim = value.size()

        assert seq_len % (window_overlap * 2) == 0
        assert attn_probs.size()[:3] == value.size()[:3]
        assert attn_probs.size(3) == 2 * window_overlap + 1

        chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
        # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
        chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
            batch_size * num_heads,
            torch.div(seq_len, window_overlap, rounding_mode="trunc"),
            window_overlap,
            2 * window_overlap + 1,
        )

        # group batch_size and num_heads dimensions into one
        value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

        # pad seq_len with w at the beginning of the sequence and another window overlap at the end
        padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)

        # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
        chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
        chunked_value_stride = padded_value.stride()
        chunked_value_stride = (
            chunked_value_stride[0],
            window_overlap * chunked_value_stride[1],
            chunked_value_stride[1],
            chunked_value_stride[2],
        )
        chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)

        chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

        context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
        return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)


@use_kernel_forward_from_hub("RMSNorm")
class HinvecRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        HinvecRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class HinvecEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: HinvecConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.self_attn = HinvecAttention(config=config, layer_idx=layer_idx)
        self.mlp = HinvecMLP(config)
        self.input_layernorm = HinvecRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = HinvecRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        is_index_masked = attention_mask < 0

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            is_index_masked=is_index_masked,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

@auto_docstring
class HinvecPreTrainedModel(PreTrainedModel):
    config: HinvecConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HinvecEncoderLayer"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": HinvecEncoderLayer,
        "attentions": HinvecAttention,
    }


class HinvecRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor

    def __init__(self, config: HinvecConfig, device=None):
        super().__init__()
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@auto_docstring
class HinvecModel(HinvecPreTrainedModel):
    config_class = HinvecConfig

    def __init__(self, config: HinvecConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [HinvecEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = HinvecRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = HinvecRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See
        base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.layers[layer].self_attn.prune_heads(heads)
    def _pad_to_window_size(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        pad_token_id: int,
    ):
        """A helper function to pad tokens and mask to work with the Longformer-style self-attention implementation."""
        # padding
        sliding_window = (
            self.config.sliding_window
            if isinstance(self.config.sliding_window, int)
            else max(self.config.sliding_window)
        )
        assert sliding_window % 2 == 0, f"`sliding_window` should be an even value. Given {sliding_window}"

        input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
        batch_size, seq_len = input_shape[:2]

        padding_len = (sliding_window - seq_len % sliding_window) % sliding_window

        # this path should be recorded in the ONNX export, it is fine with padding_len == 0 as well
        if padding_len > 0:
            logger.warning_once(
                f"Input ids are automatically padded to be a multiple of `config.sliding_window`: {sliding_window}"
            )
            if input_ids is not None:
                input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)
            if position_ids is not None:
                # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
                position_ids = nn.functional.pad(position_ids, (0, padding_len), value=pad_token_id)
            if inputs_embeds is not None:
                input_ids_padding = inputs_embeds.new_full(
                    (batch_size, padding_len),
                    self.config.pad_token_id,
                    dtype=torch.long,
                )
                inputs_embeds_padding = self.embed_tokens(input_ids_padding)
                inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)

            attention_mask = nn.functional.pad(
                attention_mask, (0, padding_len), value=0
            )  # no attention on the padding tokens
            token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0)  # pad with token_type_id = 0

        return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        device = inputs_embeds.device
        batch_size, seq_len = inputs_embeds.shape[:2]

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            position_ids = position_ids.expand(batch_size, -1)

        # 1. ----- CREATE ATTENTION MASK -----
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)
        # 2. ----- PAD TO WINDOW SIZE -----
        sliding_window = (
            self.config.sliding_window
            if isinstance(self.config.sliding_window, int)
            else max(self.config.sliding_window)
        )
        padding_len = (sliding_window - seq_len % sliding_window) % sliding_window

        if padding_len > 0:
            # Pad input_ids if available
            if input_ids is not None:
                input_ids = torch.nn.functional.pad(input_ids, (0, padding_len), value=self.config.pad_token_id)

            # Pad position_ids
            position_ids = torch.nn.functional.pad(position_ids, (0, padding_len), value=self.config.pad_token_id)

            # Pad inputs_embeds
            input_ids_padding = inputs_embeds.new_full(
                (batch_size, padding_len),
                self.config.pad_token_id,
                dtype=torch.long,
            )
            inputs_embeds_padding = self.embed_tokens(input_ids_padding)
            inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)

            # Pad attention_mask with 0s (no attention on padding)
            attention_mask = torch.nn.functional.pad(attention_mask, (0, padding_len), value=0)

        # 3. ----- CREATE PADDING MASK -----
        # Convert the binary mask to an additive float mask for attention:
        # tokens with attention_mask == 0 get torch.finfo(dtype).min, all others get 0
        expanded_mask = (attention_mask > 0).to(inputs_embeds.dtype)
        padding_mask = (1.0 - expanded_mask) * torch.finfo(inputs_embeds.dtype).min

        hidden_states = inputs_embeds

        # Create position embeddings to be shared across the encoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=padding_mask,
                position_ids=position_ids,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
        )


__all__ = ["HinvecPreTrainedModel", "HinvecModel", "HinvecRMSNorm"]
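
# Minimal usage sketch (illustrative only; assumes `HinvecConfig` exposes the standard fields referenced
# above, e.g. `hidden_size`, `num_attention_heads`, `num_key_value_heads`, `layer_types`, `sliding_window`
# and `pad_token_id` -- adjust to the actual configuration class and checkpoint):
#
#     from transformers import AutoTokenizer
#
#     config = HinvecConfig()                      # or HinvecConfig.from_pretrained(<checkpoint>)
#     model = HinvecModel(config)
#     tokenizer = AutoTokenizer.from_pretrained("<checkpoint>")   # <checkpoint> is a placeholder
#     inputs = tokenizer("Hello world", return_tensors="pt")
#     outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
#     print(outputs.last_hidden_state.shape)       # (batch, padded_seq_len, hidden_size)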