"""Apriel2 HuggingFace model implementation.""" import math import random from types import SimpleNamespace from typing import Any, Optional, TypedDict, Union import torch import torch.nn.functional as F from einops import rearrange, repeat from torch import nn from transformers import GenerationMixin, PreTrainedModel from transformers.cache_utils import Cache from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.llama.modeling_llama import eager_attention_forward from transformers.models.mistral.modeling_mistral import MistralMLP, MistralRMSNorm, apply_rotary_pos_emb from transformers.processing_utils import Unpack from transformers.utils import logging from transformers.utils.import_utils import ( is_causal_conv1d_available, is_mamba_ssm_available, is_torch_flex_attn_available, ) from .configuration_apriel2 import Apriel2Config, Apriel2TextConfig # GDN implementation - matches Fast-LLM's gdn.py exactly try: from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule except ImportError: chunk_gated_delta_rule = None fused_recurrent_gated_delta_rule = None try: from fla.modules.fused_norm_gate import rms_norm_gated except ImportError: rms_norm_gated = None # KDA implementation - matches Fast-LLM's kda.py try: from fla.ops.kda import chunk_kda, fused_recurrent_kda from fla.ops.kda.gate import fused_kda_gate except ImportError: chunk_kda = None fused_recurrent_kda = None fused_kda_gate = None try: from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn from causal_conv1d import causal_conv1d_update as _causal_conv1d_update from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update except ImportError: _causal_conv1d_fn = None _causal_conv1d_update = None selective_scan_fn = None selective_state_update = None is_fast_path_available = is_mamba_ssm_available() and is_causal_conv1d_available() if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask else: BlockMask = torch.Tensor logger = logging.get_logger(__name__) # ============================================================================= # Cache Classes # ============================================================================= class _AttentionCache: __slots__ = ["key", "value", "window", "cumulative_length"] def __init__(self, window=None): self.key = None self.value = None self.window = window self.cumulative_length = 0 def update(self, key, value): new_tokens = key.shape[-2] self.cumulative_length += new_tokens if self.key is None: if self.window and key.shape[-2] > self.window: self.key = key[..., -self.window :, :].contiguous() self.value = value[..., -self.window :, :].contiguous() else: self.key = key.contiguous() self.value = value.contiguous() else: if self.window: self.key = self._window(self.key, key) self.value = self._window(self.value, value) else: self.key = torch.cat([self.key, key], -2) self.value = torch.cat([self.value, value], -2) return self.key, self.value def _window(self, cache, new): if cache.shape[-2] == self.window and new.shape[-2] == 1: cache = cache.roll(-1, -2) cache[..., -1:, :] = new return cache return torch.cat([cache, new], -2)[..., -self.window :, :].contiguous() def reset(self): self.key = None self.value = None self.cumulative_length = 0 def reorder(self, beam_idx): if self.key is not None: self.key = self.key.index_select(0, beam_idx.to(self.key.device)) self.value = self.value.index_select(0, beam_idx.to(self.value.device)) def crop(self, max_length): if self.key is not None: self.key = self.key[..., :max_length, :] self.value = self.value[..., :max_length, :] self.cumulative_length = self.key.shape[-2] def batch_repeat(self, repeats): if self.key is not None: self.key = self.key.repeat_interleave(repeats, dim=0) self.value = self.value.repeat_interleave(repeats, dim=0) def batch_select(self, indices): if self.key is not None: self.key = self.key.index_select(0, indices.to(self.key.device)) self.value = self.value.index_select(0, indices.to(self.value.device)) @property def is_initialized(self): return self.key is not None @property def batch_size(self): return self.key.shape[0] if self.key is not None else None class _SSMCache: __slots__ = ["conv", "recurrent"] def __init__(self): self.conv = None self.recurrent = None def reset(self): self.conv = None self.recurrent = None def reorder(self, beam_idx): if self.conv is not None: if isinstance(self.conv, tuple): self.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in self.conv) else: self.conv = self.conv.index_select(0, beam_idx.to(self.conv.device)) if self.recurrent is not None: self.recurrent = self.recurrent.index_select(0, beam_idx.to(self.recurrent.device)) def crop(self, max_length): pass # SSM caches don't have sequence dimension to crop def batch_repeat(self, repeats): if self.conv is not None: if isinstance(self.conv, tuple): self.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in self.conv) else: self.conv = self.conv.repeat_interleave(repeats, dim=0) if self.recurrent is not None: self.recurrent = self.recurrent.repeat_interleave(repeats, dim=0) def batch_select(self, indices): if self.conv is not None: if isinstance(self.conv, tuple): self.conv = tuple(c.index_select(0, indices.to(c.device)) for c in self.conv) else: self.conv = self.conv.index_select(0, indices.to(self.conv.device)) if self.recurrent is not None: self.recurrent = self.recurrent.index_select(0, indices.to(self.recurrent.device)) @property def is_initialized(self): return self.conv is not None @property def batch_size(self): if self.conv is None: return None if isinstance(self.conv, tuple): return self.conv[0].shape[0] return self.conv.shape[0] class _DummyCacheLayer: pass class Apriel2Cache(Cache): def __init__(self, config): super().__init__(layer_class_to_replicate=_DummyCacheLayer) self.config = config n = config.decoder["num_blocks"] self.layers = [] self.mixer_types = [] self.active_mixers = [None] * n for i in range(n): block = config.get_block_config(i) mixer = block.get("mixer", {}) mtype = mixer.get("type", "attention") if mtype == "stochastic": sub = {} main = mixer.get("main_mixer_name") for name, cfg in mixer.get("mixers", {}).items(): if cfg.get("type") == "attention": sub[name] = _AttentionCache(cfg.get("window_size")) else: sub[name] = _SSMCache() self.layers.append(sub) self.mixer_types.append(mixer["mixers"][main].get("type") if main else "attention") elif mtype == "attention": self.layers.append(_AttentionCache(mixer.get("window_size"))) self.mixer_types.append("attention") else: self.layers.append(_SSMCache()) self.mixer_types.append(mtype) def update(self, key_states, value_states, layer_idx, cache_kwargs=None): layer = self.layers[layer_idx] if isinstance(layer, dict): mixer = self.active_mixers[layer_idx] if mixer is None: raise RuntimeError(f"Stochastic layer {layer_idx} needs active_mixer set") return layer[mixer].update(key_states, value_states) return layer.update(key_states, value_states) def set_active_mixer(self, layer_idx, mixer_name): self.active_mixers[layer_idx] = mixer_name def get_seq_length(self, layer_idx=0): """Returns the cumulative sequence length of tokens seen by the cache. For sliding window caches, this returns the total tokens seen (not just cached). This matches HuggingFace's DynamicSlidingWindowLayer behavior. """ layer = self.layers[layer_idx] if isinstance(layer, dict): mixer = self.active_mixers[layer_idx] if mixer and isinstance(layer[mixer], _AttentionCache): return layer[mixer].cumulative_length return 0 if isinstance(layer, _AttentionCache): return layer.cumulative_length return 0 def get_max_cache_shape(self, layer_idx=0): layer = self.layers[layer_idx] if isinstance(layer, dict): mixer = self.active_mixers[layer_idx] if mixer and isinstance(layer[mixer], _AttentionCache): return layer[mixer].window elif isinstance(layer, _AttentionCache): return layer.window return None def get_mask_sizes(self, cache_position, layer_idx): """Return the length and offset of the cache, used to generate the attention mask. For standard (non-sliding) attention: kv_offset = 0 (KV[0] corresponds to sequence position 0) kv_length = cumulative_length + query_length For sliding window attention: kv_offset = max(cumulative_length - window + 1, 0) kv_length = min(cumulative_length, window - 1) + query_length For SSM/linear layers: kv_offset = 0, kv_length = query_length (no KV cache to attend to) """ query_length = cache_position.shape[0] layer = self.layers[layer_idx] # Handle stochastic layers by getting the active mixer's cache if isinstance(layer, dict): mixer = self.active_mixers[layer_idx] if mixer is None: # No active mixer set, return defaults return query_length, 0 cache = layer[mixer] else: cache = layer # SSM layers don't have KV cache for attention mask purposes if isinstance(cache, _SSMCache): return query_length, 0 # Attention cache - check if sliding window if isinstance(cache, _AttentionCache): cumulative = cache.cumulative_length window = cache.window if window is not None: # Sliding window attention kv_offset = max(cumulative - window + 1, 0) if cumulative >= window: kv_length = window - 1 + query_length else: kv_length = cumulative + query_length else: # Full attention kv_offset = 0 kv_length = cumulative + query_length return kv_length, kv_offset # Fallback return query_length, 0 @property def has_previous_state(self): return any(isinstance(cache, _SSMCache) and cache.conv is not None for cache in self._iter_caches()) @property def key_cache(self): return _LayerListAccessor(self, "key") @property def value_cache(self): return _LayerListAccessor(self, "value") @property def conv_states(self): return _LayerListAccessor(self, "conv") @property def recurrent_states(self): return _LayerListAccessor(self, "recurrent") def _iter_caches(self): """Iterate over all leaf cache objects (flattening stochastic layer dicts).""" for layer in self.layers: if isinstance(layer, dict): yield from layer.values() else: yield layer def reorder_cache(self, beam_idx): for cache in self._iter_caches(): cache.reorder(beam_idx) def reset(self): for cache in self._iter_caches(): cache.reset() def crop(self, max_length): for cache in self._iter_caches(): cache.crop(max_length) def batch_repeat_interleave(self, repeats): for cache in self._iter_caches(): cache.batch_repeat(repeats) def batch_select_indices(self, indices): for cache in self._iter_caches(): cache.batch_select(indices) @property def is_compileable(self): return False @property def is_initialized(self): return any(cache.is_initialized for cache in self._iter_caches()) @property def is_sliding(self): result = [] for layer in self.layers: if isinstance(layer, dict): has_sliding = any( isinstance(cache, _AttentionCache) and cache.window is not None for cache in layer.values() ) result.append(has_sliding) elif isinstance(layer, _AttentionCache): result.append(layer.window is not None) else: result.append(False) return result @property def max_batch_size(self): for cache in self._iter_caches(): bs = cache.batch_size if bs is not None: return bs return None @property def max_cache_len(self): windows = [ cache.window for cache in self._iter_caches() if isinstance(cache, _AttentionCache) and cache.window is not None ] return min(windows) if windows else None def __len__(self): return len(self.layers) def __getitem__(self, idx): layer = self.layers[idx] if isinstance(layer, dict): mixer = self.active_mixers[idx] if mixer and isinstance(layer[mixer], _AttentionCache): c = layer[mixer] if c.key is not None: return c.key, c.value elif isinstance(layer, _AttentionCache): if layer.key is not None: return layer.key, layer.value for i, l in enumerate(self.layers): if isinstance(l, _AttentionCache) and l.key is not None: return torch.empty((0,), device=l.key.device, dtype=l.key.dtype), torch.empty( (0,), device=l.key.device, dtype=l.key.dtype ) elif isinstance(l, dict): for c in l.values(): if isinstance(c, _AttentionCache) and c.key is not None: return torch.empty((0,), device=c.key.device, dtype=c.key.dtype), torch.empty( (0,), device=c.key.device, dtype=c.key.dtype ) return torch.empty((0,)), torch.empty((0,)) class _LayerListAccessor: __slots__ = ["cache", "attr"] def __init__(self, cache, attr): self.cache = cache self.attr = attr def __getitem__(self, idx): layer = self.cache.layers[idx] if isinstance(layer, dict): mixer = self.cache.active_mixers[idx] if mixer is None: raise RuntimeError( f"Stochastic layer {idx} requires set_active_mixer() to be called before accessing cache. " f"Available mixers: {list(layer.keys())}" ) return getattr(layer[mixer], self.attr) return getattr(layer, self.attr, None) def __setitem__(self, idx, value): layer = self.cache.layers[idx] if isinstance(layer, dict): mixer = self.cache.active_mixers[idx] if mixer is None: raise RuntimeError( f"Stochastic layer {idx} requires set_active_mixer() to be called before accessing cache. " f"Available mixers: {list(layer.keys())}" ) setattr(layer[mixer], self.attr, value) elif hasattr(layer, self.attr): setattr(layer, self.attr, value) # ============================================================================= # TypedDict Classes # ============================================================================= class BlockSequenceKwargs(TypedDict, total=False): attention_mask: Optional[torch.Tensor] position_ids: Optional[torch.LongTensor] cache_position: Optional[torch.LongTensor] past_key_values: Optional[Apriel2Cache] output_attentions: bool output_hidden_states: bool use_cache: bool class PreprocessingOutput(TypedDict, total=False): position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] attention_mask: Optional[torch.Tensor] class CausalConv1d(nn.Conv1d): """ Causal 1D convolution that pads only on the left side. Subclasses nn.Conv1d for weight storage/checkpoint compatibility, but overrides forward to use proper causal (left-only) padding instead of nn.Conv1d's symmetric padding. Supports: - Prefill mode: process full sequence, optionally return final state for caching - Decode mode: single-token update using cached conv state Requires causal_conv1d library for CUDA kernels (no PyTorch fallback). """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int, activation: str = "silu", **kwargs, ): if not is_fast_path_available: raise ImportError( "CausalConv1d requires CUDA kernels from causal_conv1d and mamba_ssm. " "Install with: pip install causal-conv1d mamba-ssm" ) # Remove padding from kwargs since we handle it ourselves kwargs.pop("padding", None) super().__init__( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=0, # No built-in padding; we handle it in forward **kwargs, ) self._activation = activation @property def _weight(self) -> torch.Tensor: """Weight in [dim, kernel_size] format for causal_conv1d functions.""" return self.weight.squeeze(1) def forward( self, x: torch.Tensor, conv_state: torch.Tensor | None = None, return_final_state: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Apply causal convolution. Args: x: Input tensor [batch, dim, seq_len] conv_state: Previous conv state [batch, dim, kernel_size-1] for continuing from cached state. If None, starts fresh. return_final_state: If True, return (output, final_state) tuple where final_state can be used for subsequent decode steps. Returns: If return_final_state is False: output tensor [batch, dim, seq_len] If return_final_state is True: (output, final_state) tuple """ batch_size, dim, seq_len = x.shape state_len = self.kernel_size[0] - 1 # Edge case: seq_len==1 with return_final_state # CUDA kernel limitation: return_final_states requires channel-last layout, # which is impossible when seq_len==1. Handle via update() with zero-init state. if return_final_state and seq_len == 1: # Initialize zero state if none provided, with channel-last layout for CUDA kernel if conv_state is None: # Create channel-last state: stride(1) == 1 conv_state = x.new_zeros(batch_size, state_len, dim).transpose(1, 2) # Use update() which handles single tokens efficiently out = _causal_conv1d_update( x.squeeze(2), # [batch, dim, 1] -> [batch, dim] conv_state, self._weight, bias=self.bias, activation=self._activation, ) return out.unsqueeze(2), conv_state # [batch, dim, 1], updated state # Standard CUDA path if return_final_state: # causal_conv1d requires channel-last layout for returning final states. # Channel-last means: stride(1)==1 AND stride(2)==dim (channels are contiguous). # For shape [batch, dim, seq], standard contiguous is (dim*seq, seq, 1). # Channel-last is (dim*seq, 1, dim) - achieved via transpose+contiguous+transpose. if x.stride(1) != 1 or x.stride(2) < dim: x = x.transpose(1, 2).contiguous().transpose(1, 2) # Allocate final state buffer with correct memory layout # causal_conv1d requires final_states.stride(1) == 1 final_state = x.new_zeros(batch_size, state_len, dim).transpose(1, 2) else: final_state = None out = _causal_conv1d_fn( x, self._weight, bias=self.bias, initial_states=conv_state, return_final_states=return_final_state, final_states_out=final_state, activation=self._activation, ) if return_final_state: if isinstance(out, tuple): out, final_state = out # final_state has shape [batch, dim, state_len] with channel-last strides # Ensure it's safe for in-place updates by subsequent CUDA kernel calls assert final_state is not None if final_state.stride(1) == 1: # Make a copy that's safe to modify in-place final_state = final_state.clone() return out, final_state return out def update( self, x: torch.Tensor, conv_state: torch.Tensor, ) -> torch.Tensor: """ Single-token decode step using cached conv state. Args: x: Input tensor [batch, dim] (single token) conv_state: Conv state [batch, dim, kernel_size-1], will be updated in-place Returns: Output tensor [batch, dim] """ return _causal_conv1d_update( x, conv_state, self._weight, bias=self.bias, activation=self._activation, ) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @torch.compile def segsum(x): T = x.size(-1) x = repeat(x, "... d -> ... d e", e=T) mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) x = x.masked_fill(~mask, 0) x_segsum = torch.cumsum(x, dim=-2) mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) x_segsum = x_segsum.masked_fill(~mask, -torch.inf) return x_segsum @torch.compile def materialize_mixer(A_log, B, C, D): batch_size, length, n_heads, d_state = B.shape assert A_log.shape == (batch_size, length, n_heads) assert B.shape == C.shape == (batch_size, length, n_heads, d_state) A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") powers = torch.exp(segsum(A_log)) T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) if D is not None: T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) T = rearrange(T, "b h z l -> b h l z") return T def apply_mask_to_padding_states(hidden_states, attention_mask): if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: dtype = hidden_states.dtype hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) return hidden_states class Apriel2Attention(nn.Module): """Multi-headed attention with support for GQA and configurable causality. Config options (Fast-LLM naming): heads: Number of query heads head_groups: Number of key/value heads (for GQA) head_size: Dimension per head add_linear_biases: Whether to use biases in projections causal: Whether to use causal masking window_size: Optional sliding window size rotary: Rotary embedding config dict """ def __init__(self, d_model: int, mixer_config: dict, layer_idx: int, config): super().__init__() self.config = config self.mixer_config = mixer_config self.layer_idx = layer_idx # Extract config using Fast-LLM naming self.num_heads = mixer_config["heads"] self.num_key_value_heads = mixer_config.get("head_groups", self.num_heads) self.head_dim = mixer_config["head_size"] self.hidden_size = d_model self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = self.head_dim**-0.5 self.is_causal = mixer_config.get("causal", True) self.window_size = mixer_config.get("window_size") # cross_document_attention: if False, use cu_seqlens to isolate sequences (e.g., images) self.cross_document_attention = mixer_config.get("cross_document_attention", True) # Bias configuration mirroring Fast-LLM's structure: # - add_linear_biases: bool (default for all projections) # - query_layer: {"bias": {"enabled": bool}} (per-layer override) # - key_layer: {"bias": {"enabled": bool}} # - value_layer: {"bias": {"enabled": bool}} # - dense_layer: {"bias": {"enabled": bool}} default_bias = mixer_config.get("add_linear_biases", False) def get_layer_bias(layer_name: str) -> bool: layer_cfg = mixer_config.get(layer_name, {}) bias_cfg = layer_cfg.get("bias", {}) enabled = bias_cfg.get("enabled") return default_bias if enabled is None else enabled q_bias = get_layer_bias("query_layer") k_bias = get_layer_bias("key_layer") v_bias = get_layer_bias("value_layer") o_bias = get_layer_bias("dense_layer") # Projections self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=q_bias) self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=k_bias) self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=v_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=o_bias) @classmethod def setup( cls, mixer_config: dict, hidden_size: int, max_position_embeddings: int, ) -> nn.ModuleDict: """ Setup resources needed by this mixer (rotary embeddings). Called once per block type, before instances are created. Args: mixer_config: Mixer configuration dict hidden_size: Model hidden size max_position_embeddings: Maximum sequence length Returns: ModuleDict containing 'rotary_emb' """ rotary_config_dict = mixer_config["rotary"] rotary_type = rotary_config_dict["type"] rope_theta = rotary_config_dict["theta"] num_heads = mixer_config["heads"] head_dim = mixer_config["head_size"] if rotary_type == "pixtral_2d": from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding rotary_config = SimpleNamespace( head_dim=head_dim, rope_theta=rope_theta, image_size=rotary_config_dict["max_image_size"], patch_size=rotary_config_dict["patch_size"], ) return nn.ModuleDict({"rotary_emb": PixtralRotaryEmbedding(config=rotary_config)}) elif rotary_type == "mistral_1d": from transformers.models.mistral.modeling_mistral import MistralRotaryEmbedding rotary_config = SimpleNamespace( max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, head_dim=head_dim, hidden_size=hidden_size, num_attention_heads=num_heads, partial_rotary_factor=1.0, ) return nn.ModuleDict({"rotary_emb": MistralRotaryEmbedding(config=rotary_config)}) else: raise ValueError(f"Unknown rotary type: {rotary_type}") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple] = None, past_key_values: Optional[Any] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ): input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if position_embeddings is not None: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # Select attention implementation attention_interface = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0, scaling=self.scaling, sliding_window=self.window_size, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights def preprocess( self, hidden_states: torch.Tensor, resources: Optional[nn.ModuleDict], **kwargs: Unpack[BlockSequenceKwargs], ) -> PreprocessingOutput: """ Compute attention preprocessing: position embeddings and masks. Args: hidden_states: Current hidden states (for shape/device) resources: ModuleDict of resources from setup() (contains 'rotary_emb') **kwargs: Metadata including: - position_ids: Position IDs for rotary embedding - sequence_lengths: [n1, n2, ...] for sequence isolation - attention_mask, cache_position, past_key_values, etc. Returns: PreprocessingOutput with position_embeddings, attention_mask, and flash_attn_kwargs """ position_ids = kwargs.get("position_ids") # Compute position embeddings using rotary_emb from resources position_embeddings = None if resources is not None and "rotary_emb" in resources and position_ids is not None: rotary_emb = resources["rotary_emb"] cos, sin = rotary_emb(hidden_states, position_ids) position_embeddings = (cos, sin) # Handle sequence isolation (cross_document_attention=False) sequence_lengths = kwargs.get("sequence_lengths") flash_attn_kwargs = {} mask = kwargs.get("attention_mask") if not self.cross_document_attention and sequence_lengths is not None: # Compute cu_seqlens for flash attention or block diagonal mask for others attn_impl = getattr(self.config, "_attn_implementation", "eager") if attn_impl == "flash_attention_2": # Flash attention: use cu_seqlens for varlen attention cu_seqlens = torch.tensor( [0] + list(torch.cumsum(torch.tensor(sequence_lengths), dim=0).tolist()), dtype=torch.int32, device=hidden_states.device, ) max_seqlen = max(sequence_lengths) flash_attn_kwargs = { "cu_seq_lens_q": cu_seqlens, "cu_seq_lens_k": cu_seqlens, "max_length_q": max_seqlen, "max_length_k": max_seqlen, } mask = None # Flash varlen doesn't use attention_mask else: # Non-flash: use block diagonal mask mask = _generate_block_attention_mask(sequence_lengths, hidden_states) elif self.is_causal and kwargs.get("cache_position") is not None: # Causal attention - compute causal mask mask_function = create_causal_mask if self.window_size is None else create_sliding_window_causal_mask # Build config for mask creation mask_config = SimpleNamespace( hidden_size=self.config.hidden_size, num_attention_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads, head_dim=self.head_dim, max_position_embeddings=self.config.embeddings["max_position_embeddings"], sliding_window=self.window_size, _attn_implementation=getattr(self.config, "_attn_implementation", "eager"), ) mask = mask_function( config=mask_config, input_embeds=hidden_states, attention_mask=kwargs.get("attention_mask"), cache_position=kwargs["cache_position"], past_key_values=kwargs.get("past_key_values"), position_ids=position_ids, ) # Return computed tensors return { "position_embeddings": position_embeddings, "attention_mask": mask, **flash_attn_kwargs, } # Shared helper functions for both text and vision models def get_mixer_class(mixer_type: str) -> type: """Map mixer type string to mixer class.""" if mixer_type == "attention": return Apriel2Attention elif mixer_type == "mamba": return Apriel2Mamba elif mixer_type == "gdn": return Apriel2GatedDeltaNet elif mixer_type == "kda": return KimiDeltaAttention elif mixer_type == "stochastic": return Apriel2StochasticMixer else: raise ValueError(f"Unknown mixer type: {mixer_type}") def create_mixer(mixer_config: dict, hidden_size: int, layer_idx: int, config, allow_stochastic: bool = True): """Create a mixer instance from config. Uses get_mixer_class() for type→class mapping.""" # TODO: make constructor signatures uniform across mixer types and remove this function mixer_type = mixer_config.get("type", "attention") mixer_class = get_mixer_class(mixer_type) # Handles unknown types # Different mixer types have different constructor signatures if mixer_type == "attention": return mixer_class(hidden_size, mixer_config, layer_idx, config) elif mixer_type == "stochastic": if not allow_stochastic: raise ValueError("Stochastic mixers cannot contain nested stochastic mixers") return mixer_class(mixer_config, config, layer_idx) else: # mamba, gdn, kda all have same signature return mixer_class(hidden_size, mixer_config, layer_idx=layer_idx) class Apriel2PatternMixerAdapter(nn.Module): """Adapter that wraps a single mixer under mixers.{name} to match supernet weight paths. The supernet checkpoint stores weights as blocks.{i}.mixer.mixers.{type}.{param}, but a bare mixer creates blocks.{i}.mixer.{param}. This adapter adds the intermediate mixers.{name} level so pattern configs can load from supernet checkpoints. """ def __init__(self, mixer_name: str, mixer: nn.Module): super().__init__() self.mixers = nn.ModuleDict({mixer_name: mixer}) self._mixer_name = mixer_name def forward(self, *args, **kwargs): return self.mixers[self._mixer_name](*args, **kwargs) def preprocess(self, *args, **kwargs): return self.mixers[self._mixer_name].preprocess(*args, **kwargs) @classmethod def setup(cls, mixer_name: str, mixer_config: dict, hidden_size: int, max_position_embeddings: int) -> nn.ModuleDict: mixer_type = mixer_config.get("type", "attention") mixer_class = get_mixer_class(mixer_type) return mixer_class.setup(mixer_config, hidden_size, max_position_embeddings) class Apriel2Mamba(nn.Module): """Mamba mixer.""" def __init__( self, d_model, config_dict: dict, layer_idx=None, device=None, dtype=None, ): """Initialize Mamba from a config dictionary.""" factory_kwargs = {"device": device, "dtype": dtype} super().__init__() # Extract parameters from config dict d_state = config_dict.get("state_size", 16) d_inner = config_dict.get("d_inner") d_xb = config_dict.get("d_xb", None) d_conv = config_dict.get("d_conv", 4) expand = config_dict.get("expand", 2) dt_rank = config_dict.get("dt_rank", "auto") dt_min = config_dict.get("dt_min", 0.001) dt_max = config_dict.get("dt_max", 0.1) dt_init = config_dict.get("dt_init", "random") dt_scale = config_dict.get("dt_scale", 1.0) dt_init_floor = config_dict.get("dt_init_floor", 1e-4) repeat_kv_before_conv = config_dict.get("repeat_kv_before_conv", True) conv_bias = config_dict["conv_bias"] bias = config_dict.get("add_linear_biases", False) dt_proj_bias = config_dict["dt_proj_bias"] self.d_model = d_model self.d_xb = d_xb if d_xb is not None else d_model self.d_state = d_state self.d_conv = d_conv self.expand = expand self.d_inner = d_inner if d_inner is not None else int(self.expand * self.d_model) self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank self.use_fast_path = True self.layer_idx = layer_idx self.repeat_kv_before_conv = repeat_kv_before_conv self.activation = "silu" # Hardcoded for Mamba if self.repeat_kv_before_conv: self.conv1d = CausalConv1d( in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias, kernel_size=d_conv, groups=self.d_inner, activation=self.activation, **factory_kwargs, ) else: self.conv1d = CausalConv1d( in_channels=self.d_xb, out_channels=self.d_xb, bias=conv_bias, kernel_size=d_conv, groups=self.d_xb, activation=self.activation, **factory_kwargs, ) self.num_xb_head = self.d_xb // self.d_state self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization dt_init_std = self.dt_rank**-0.5 * dt_scale if dt_init == "constant": nn.init.constant_(self.dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max if self.dt_proj.bias is not None: dt = torch.exp( torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt) self.dt_proj.bias._no_reinit = True # S4D real initialization A = repeat( torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=self.d_inner, ).contiguous() A_log = torch.log(A) self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True # D "skip" parameter self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) self.D._no_weight_decay = True self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) def forward( self, hidden_states: torch.Tensor, past_key_values=None, attention_mask: Optional[torch.Tensor] = None, **kwargs, ): """Forward pass for Mamba.""" # Check for CUDA when using fast path if is_fast_path_available and "cuda" not in self.in_proj.weight.device.type: raise RuntimeError( "Mamba with CUDA kernels requires CUDA device. Current device: " + str(self.in_proj.weight.device) ) cache_position = kwargs.get("cache_position", None) batch, seqlen, dim = hidden_states.shape ssm_state, conv_state = None, None use_precomputed_states = False seqlen_offset = kwargs.get("seqlen_offset", cache_position[0]) if cache_position is not None else 0 use_precomputed_states = ( past_key_values is not None and isinstance(past_key_values, Apriel2Cache) and past_key_values.conv_states[self.layer_idx] is not None and seqlen == 1 and past_key_values.conv_states[self.layer_idx].shape[0] == past_key_values.recurrent_states[self.layer_idx].shape[0] == batch and cache_position is not None and seqlen_offset > 0 ) ssm_state, conv_state = self._get_states_from_cache(past_key_values, batch) # Adaptive mode selection: use step() for single-token generation # This provides significant speedup during autoregressive decoding if use_precomputed_states: out, _, _ = self.step(hidden_states, conv_state, ssm_state) return (out,) A = -torch.exp(self.A_log.float()) zxbc = self.in_proj(hidden_states) z, x, B, C = torch.split( zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1, ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") B = rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state) B = repeat_kv(B, self.repeat_group) B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() dt = self.dt_proj(self.dt_in_proj(hidden_states)) dt = rearrange(dt, "b l d -> b d l") if self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) x = repeat_kv(x, self.repeat_group) x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") # Compute short convolution if conv_state is not None: # Store padded input for future decode steps (convention: state size = d_conv) conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) x = self.conv1d(x) if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) x = repeat_kv(x, self.repeat_group) x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") y = selective_scan_fn( x, dt, A, B, C, self.D.float(), z=z, delta_bias=self.dt_proj.bias.float() if self.dt_proj.bias is not None else None, delta_softplus=True, return_last_state=(ssm_state is not None), ) if ssm_state is not None: y, last_state = y ssm_state.copy_(rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) y = rearrange(y, "b d l -> b l d") out = self.out_proj(y) return (out[:, :seqlen, :],) @classmethod def setup( cls, mixer_config: dict, hidden_size: int, max_position_embeddings: int, ) -> nn.ModuleDict: """Mamba has no setup resources - returns empty ModuleDict.""" return nn.ModuleDict() def preprocess( self, hidden_states: torch.Tensor, resources: Optional[nn.ModuleDict], **kwargs: Unpack[BlockSequenceKwargs], ) -> PreprocessingOutput: """Mamba has no preprocessing - returns empty dict.""" return {} def step(self, hidden_states, conv_state, ssm_state): hidden_states.dtype assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" hidden_states_input = hidden_states.squeeze(1) A = -torch.exp(self.A_log.float()) zxbc = self.in_proj(hidden_states_input) z, x, B, C = torch.split( zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1, ) B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group) C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous() dt = self.dt_proj(self.dt_in_proj(hidden_states_input)) if self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group) x = rearrange(x, "b n_group dstate -> b (n_group dstate)") # Conv step x = self.conv1d.update(x, conv_state) if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group) x = rearrange(x, "b n_group dstate -> b (n_group dstate)") x = rearrange(x, "b (h d) -> b h d", h=self.num_C_head) dt = rearrange(dt, "b (h d) -> b h d", h=self.num_C_head) A = rearrange(A, "(h d) n -> h d n", h=self.num_C_head) D = rearrange(self.D, "(h d) -> h d", h=self.num_C_head) z = rearrange(z, "b (h d) -> b h d", h=self.num_C_head) dt_bias = ( rearrange(self.dt_proj.bias, "(h d) -> h d", h=self.num_C_head) if self.dt_proj.bias is not None else None ) # SSM step y = selective_state_update(ssm_state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True) y = rearrange(y, "b h d -> b (h d)") out = self.out_proj(y) return out.unsqueeze(1), conv_state, ssm_state def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): device = self.out_proj.weight.device conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype if self.repeat_kv_before_conv: conv_state = torch.zeros(batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype) else: conv_state = torch.zeros(batch_size, self.d_xb, self.d_conv, device=device, dtype=conv_dtype) ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype ssm_state = torch.zeros( batch_size, self.num_C_head, self.d_inner // self.num_C_head, self.d_state, device=device, dtype=ssm_dtype ) return conv_state, ssm_state def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): assert self.layer_idx is not None if inference_params is None or not isinstance(inference_params, Apriel2Cache): return None, None if inference_params.conv_states[self.layer_idx] is None: conv_state, ssm_state = self.allocate_inference_cache(batch_size, max_seqlen=0) inference_params.conv_states[self.layer_idx] = conv_state inference_params.recurrent_states[self.layer_idx] = ssm_state ssm_state = inference_params.recurrent_states[self.layer_idx] conv_state = inference_params.conv_states[self.layer_idx] if initialize_states: ssm_state.zero_() conv_state.zero_() return ssm_state, conv_state def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: """L2 normalization matching Fast-LLM's implementation.""" return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) class GatedRMSNormalization(nn.Module): """ Gated RMS normalization layer matching Fast-LLM's implementation. Uses fla.modules.fused_norm_gate.rms_norm_gated (required). Args: hidden_size: Size of the hidden dimension eps: Epsilon for numerical stability activation: Gating activation function ("silu" or "sigmoid") """ def __init__(self, hidden_size: int, eps: float = 1e-5, activation: str = "silu"): super().__init__() if rms_norm_gated is None: raise ImportError( "GatedRMSNormalization requires rms_norm_gated from fla library. " "Install with: pip install fla-core" ) self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps self.activation = activation def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: return rms_norm_gated( input_, gate, self.weight, None, activation=self.activation, eps=self.eps, residual=None, prenorm=False, residual_in_fp32=False, ) class Apriel2GatedDeltaNet(nn.Module): """ Gated Delta Net implementation matching Fast-LLM's gdn.py exactly. Weight names and config parameters match Fast-LLM: - in_proj_qkvz, in_proj_ba, convolution, out_proj, dt_bias, A_log, norm - value_heads, key_heads, key_head_dim, value_head_dim Uses Fast-LLM's flat QKVZ layout: [Q_all | K_all | V_all | Z_all] Uses fla.ops.gated_delta_rule.chunk_gated_delta_rule when available. """ def __init__( self, d_model, config_dict: dict, layer_idx=None, device=None, dtype=None, ): super().__init__() self.layer_idx = layer_idx self.hidden_size = d_model # Config params - match Fast-LLM naming (value_heads, key_heads, etc.) self.activation = config_dict["convolution_layer"].get("activation", "silu") self.value_heads = config_dict.get("value_heads", 32) self.key_heads = config_dict.get("key_heads", 8) self.key_head_dim = config_dict.get("key_head_dim", 64) self.value_head_dim = config_dict.get("value_head_dim", 64) self.conv_kernel_size = config_dict["convolution_layer"]["kernel_size"] self.norm_eps = config_dict.get("norm_eps", 1e-5) # Derived dimensions self.key_dim = self.key_head_dim * self.key_heads self.value_dim = self.value_head_dim * self.value_heads self.conv_dim = self.key_dim * 2 + self.value_dim # Q, K, V (no Z in conv) self.qkvz_dim = self.key_dim * 2 + self.value_dim * 2 # Q, K, V, Z self.value_heads_per_key = self.value_heads // self.key_heads # Projection layers - names match Fast-LLM exactly self.in_proj_qkvz = nn.Linear(d_model, self.qkvz_dim, bias=False, device=device, dtype=dtype) self.in_proj_ba = nn.Linear(d_model, self.value_heads * 2, bias=False, device=device, dtype=dtype) self.out_proj = nn.Linear(self.value_dim, d_model, bias=False, device=device, dtype=dtype) # Convolution - named 'convolution' to match Fast-LLM self.convolution = CausalConv1d( in_channels=self.conv_dim, out_channels=self.conv_dim, bias=False, kernel_size=self.conv_kernel_size, groups=self.conv_dim, activation=self.activation, device=device, dtype=dtype, ) # Learnable parameters - match Fast-LLM initialization self.dt_bias = nn.Parameter(torch.ones(self.value_heads, device=device, dtype=dtype)) self.A_log = nn.Parameter(torch.zeros(self.value_heads, device=device, dtype=dtype).uniform_(0, 16).log()) # Normalization layer - named 'norm' with 'weight' param to match Fast-LLM self.norm = GatedRMSNormalization(self.value_head_dim, eps=self.norm_eps) # Require FLA kernels - no silent fallback to unoptimized code paths if chunk_gated_delta_rule is None or fused_recurrent_gated_delta_rule is None: raise ImportError( "GatedDeltaNet requires the fla library for optimized kernels. " "Install with: pip install fla-core" ) def _fix_query_key_value_ordering(self, mixed_qkvz: torch.Tensor, mixed_ba: torch.Tensor): """ Split QKVZ and BA tensors using Fast-LLM's flat layout. Fast-LLM layout: [Q_all_heads | K_all_heads | V_all_heads | Z_all_heads] """ # Split QKVZ - flat layout matching Fast-LLM qkv_sizes = ( self.key_dim, # Q: key_heads * key_head_dim self.key_dim, # K: key_heads * key_head_dim self.value_dim, # V: value_heads * value_head_dim self.value_dim, # Z: value_heads * value_head_dim ) query, key, value, z = torch.split(mixed_qkvz, qkv_sizes, dim=-1) # Reshape to head format: [batch, seq, heads, head_dim] query = query.reshape(*query.shape[:-1], self.key_heads, self.key_head_dim) key = key.reshape(*key.shape[:-1], self.key_heads, self.key_head_dim) value = value.reshape(*value.shape[:-1], self.value_heads, self.value_head_dim) z = z.reshape(*z.shape[:-1], self.value_heads, self.value_head_dim) # Split BA - flat layout: [beta_all | alpha_all] beta, alpha = torch.split(mixed_ba, (self.value_heads, self.value_heads), dim=-1) return query, key, value, z, beta, alpha def _ensure_cache_initialized(self, past_key_values, batch_size, device, dtype): """Initialize cache if it doesn't exist for this layer.""" if past_key_values is None: return if past_key_values.conv_states[self.layer_idx] is None: conv_state = torch.zeros(batch_size, self.conv_dim, self.conv_kernel_size, device=device, dtype=dtype) past_key_values.conv_states[self.layer_idx] = conv_state if past_key_values.recurrent_states[self.layer_idx] is None: recurrent_state = torch.zeros( batch_size, self.value_heads, self.key_head_dim, self.value_head_dim, device=device, dtype=dtype ) past_key_values.recurrent_states[self.layer_idx] = recurrent_state def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_mask=None, **kwargs): cache_position = kwargs.get("cache_position", None) batch_size, seq_len, _ = hidden_states.shape # Get conv and recurrent state from cache if available conv_state = None recurrent_state = None if past_key_values is not None: conv_state = past_key_values.conv_states[self.layer_idx] recurrent_state = past_key_values.recurrent_states[self.layer_idx] # Check if using precomputed states (single token decode with cache) # Must check that conv_state exists for THIS layer (not just overall has_previous_state) use_precomputed_states = ( past_key_values is not None and conv_state is not None and seq_len == 1 and cache_position is not None ) # Project to QKVZ and BA mixed_qkvz = self.in_proj_qkvz(hidden_states) mixed_ba = self.in_proj_ba(hidden_states) # Split into components using Fast-LLM's flat layout query, key, value, z, beta, alpha = self._fix_query_key_value_ordering(mixed_qkvz, mixed_ba) # Flatten QKV for convolution (no Z in conv) query_flat = query.reshape(batch_size, seq_len, -1) key_flat = key.reshape(batch_size, seq_len, -1) value_flat = value.reshape(batch_size, seq_len, -1) mixed_qkv = torch.cat([query_flat, key_flat, value_flat], dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) # [batch, conv_dim, seq] # Apply causal convolution if use_precomputed_states: # Single token decode - use cached conv state mixed_qkv = self.convolution.update( mixed_qkv.squeeze(2), # [batch, conv_dim, 1] -> [batch, conv_dim] conv_state, ).unsqueeze( 2 ) # [batch, conv_dim] -> [batch, conv_dim, 1] else: # Prefill mode use_cache = past_key_values is not None if use_cache: mixed_qkv, final_state = self.convolution(mixed_qkv, return_final_state=True) past_key_values.conv_states[self.layer_idx] = final_state else: mixed_qkv = self.convolution(mixed_qkv) mixed_qkv = mixed_qkv.transpose(1, 2) # [batch, seq, conv_dim] # Split back after convolution query_flat, key_flat, value_flat = torch.split(mixed_qkv, (self.key_dim, self.key_dim, self.value_dim), dim=-1) query = query_flat.reshape(batch_size, seq_len, self.key_heads, self.key_head_dim) key = key_flat.reshape(batch_size, seq_len, self.key_heads, self.key_head_dim) value = value_flat.reshape(batch_size, seq_len, self.value_heads, self.value_head_dim) # Compute gating - match Fast-LLM exactly beta_gate = beta.sigmoid() g = -self.A_log.float().exp() * F.softplus(alpha.float() + self.dt_bias) # Expand K heads to V heads if grouped query attention if self.value_heads_per_key > 1: query = query.repeat_interleave(self.value_heads_per_key, dim=2) key = key.repeat_interleave(self.value_heads_per_key, dim=2) # Run gated delta rule (FLA kernels required) if not use_precomputed_states: # Chunked mode for prefill output, last_recurrent_state = chunk_gated_delta_rule( query, key, value, g=g, beta=beta_gate, initial_state=recurrent_state, output_final_state=past_key_values is not None, use_qk_l2norm_in_kernel=True, ) # Ensure state is in same dtype as hidden_states (fla kernel may return float32) if last_recurrent_state is not None: last_recurrent_state = last_recurrent_state.to(hidden_states.dtype) else: # Recurrent mode for single token decode output, last_recurrent_state = fused_recurrent_gated_delta_rule( query, key, value, g=g, beta=beta_gate, initial_state=recurrent_state, output_final_state=past_key_values is not None, use_qk_l2norm_in_kernel=True, ) # Update recurrent state in cache if past_key_values is not None: past_key_values.recurrent_states[self.layer_idx] = last_recurrent_state # Apply gated normalization z_shape_og = z.shape output = output.reshape(-1, output.shape[-1]) z_flat = z.reshape(-1, z.shape[-1]) output = self.norm(output, z_flat) output = output.reshape(z_shape_og) output = output.reshape(output.shape[0], output.shape[1], -1) # Output projection output = self.out_proj(output) return (output,) @classmethod def setup( cls, mixer_config: dict, hidden_size: int, max_position_embeddings: int, ) -> nn.ModuleDict: """GatedDeltaNet has no setup resources - returns empty ModuleDict.""" return nn.ModuleDict() def preprocess( self, hidden_states: torch.Tensor, resources: Optional[nn.ModuleDict], **kwargs: Unpack[BlockSequenceKwargs], ) -> PreprocessingOutput: """GatedDeltaNet has no preprocessing - returns empty dict.""" return {} class KimiDeltaAttention(nn.Module): """ Kimi Delta Attention (KDA) implementation matching Fast-LLM's kda.py. Weight names match Fast-LLM: - q_proj, k_proj, v_proj, o_proj - main projections - f_a_proj, f_b_proj - gate kernel (low-rank) - g_a_proj, g_b_proj - output gate (low-rank) - beta_proj - beta gating - q_conv, k_conv, v_conv - CausalConv1d modules - A_log, dt_bias - learnable parameters - norm - gated RMS normalization Uses fla.ops.kda.chunk_kda and fused_recurrent_kda kernels. Uses CausalConv1d for convolutions (requires causal_conv1d CUDA kernels). """ def __init__( self, d_model, config_dict: dict, layer_idx=None, device=None, dtype=None, ): super().__init__() if chunk_kda is None or fused_kda_gate is None: raise ImportError( "KimiDeltaAttention requires the `fla` package. " "Please install it with `pip install -U fla-core`." ) self.layer_idx = layer_idx self.hidden_size = d_model self.mode = "chunk" # Config params - match Fast-LLM naming self.num_heads = config_dict.get("heads", 32) self.head_dim = config_dict.get("head_dim", 64) conv_config = config_dict.get("convolution_layer", {}) self.conv_kernel_size = conv_config.get("kernel_size", 4) norm_config = config_dict.get("normalization", {}) self.norm_eps = norm_config.get("epsilon", 1e-5) self.norm_activation = norm_config.get( "activation", "silu" ) # default to silu to be consistent with Fast-LLM's default. Note, Kimi uses sigmoid. # Derived dimensions self.projection_size = self.head_dim * self.num_heads # Projection layers - names match Fast-LLM exactly self.q_proj = nn.Linear(d_model, self.projection_size, bias=False, device=device, dtype=dtype) self.k_proj = nn.Linear(d_model, self.projection_size, bias=False, device=device, dtype=dtype) self.v_proj = nn.Linear(d_model, self.projection_size, bias=False, device=device, dtype=dtype) # Convolutions - use CausalConv1d for proper left-only padding # Named to match Fast-LLM (q_conv, k_conv, v_conv) self.q_conv = CausalConv1d( in_channels=self.projection_size, out_channels=self.projection_size, kernel_size=self.conv_kernel_size, groups=self.projection_size, # depthwise bias=False, activation="silu", device=device, dtype=dtype, ) self.k_conv = CausalConv1d( in_channels=self.projection_size, out_channels=self.projection_size, kernel_size=self.conv_kernel_size, groups=self.projection_size, bias=False, activation="silu", device=device, dtype=dtype, ) self.v_conv = CausalConv1d( in_channels=self.projection_size, out_channels=self.projection_size, kernel_size=self.conv_kernel_size, groups=self.projection_size, bias=False, activation="silu", device=device, dtype=dtype, ) # Gate kernel projections (low-rank: hidden -> head_dim -> projection) self.f_a_proj = nn.Linear(d_model, self.head_dim, bias=False, device=device, dtype=dtype) self.f_b_proj = nn.Linear(self.head_dim, self.projection_size, bias=False, device=device, dtype=dtype) # Output gate projections (low-rank) self.g_a_proj = nn.Linear(d_model, self.head_dim, bias=False, device=device, dtype=dtype) self.g_b_proj = nn.Linear(self.head_dim, self.projection_size, bias=False, device=device, dtype=dtype) # Beta projection - named beta_proj to match Fast-LLM (not b_proj) self.beta_proj = nn.Linear(d_model, self.num_heads, bias=False, device=device, dtype=dtype) # Output projection self.o_proj = nn.Linear(self.projection_size, d_model, bias=False, device=device, dtype=dtype) # Learnable parameters - match Fast-LLM shapes # A_log: 1D shape (num_heads,) to match Fast-LLM self.A_log = nn.Parameter( torch.zeros(self.num_heads, device=device, dtype=torch.float32).uniform_(1, 16).log() ) self.dt_bias = nn.Parameter(torch.ones(self.projection_size, device=device, dtype=torch.float32)) # Normalization - use GatedRMSNormalization (same wrapper as GDN, with sigmoid activation) self.norm = GatedRMSNormalization(self.head_dim, eps=self.norm_eps, activation=self.norm_activation) def _apply_conv(self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool): """ Apply causal convolution with cache support. Args: x: Input tensor [batch, seq, dim] conv: CausalConv1d module conv_state: Previous conv state [batch, dim, kernel_size-1] or None use_cache: Whether to output final state for caching Returns: (output, new_conv_state) tuple """ seq_len = x.shape[1] x = x.transpose(1, 2) # [batch, dim, seq] # Single token decode with existing cache if conv_state is not None and seq_len == 1: out = conv.update(x.squeeze(2), conv_state) return out.unsqueeze(1), conv_state # [batch, 1, dim] # Prefill mode if use_cache: out, final_state = conv(x, conv_state=conv_state, return_final_state=True) else: out = conv(x, conv_state=conv_state) final_state = None return out.transpose(1, 2), final_state # [batch, seq, dim] def forward( self, hidden_states: torch.Tensor, past_key_values=None, attention_mask: Optional[torch.Tensor] = None, **kwargs, ): batch_size, seq_len, _ = hidden_states.shape mode = "fused_recurrent" if (seq_len <= 64 and not self.training) else self.mode # Get cache states if available conv_state_q, conv_state_k, conv_state_v = None, None, None recurrent_state = None use_cache = past_key_values is not None if past_key_values is not None: conv_states = past_key_values.conv_states[self.layer_idx] if conv_states is not None: conv_state_q, conv_state_k, conv_state_v = conv_states recurrent_state = past_key_values.recurrent_states[self.layer_idx] # Project Q, K, V and apply convolutions q, conv_state_q = self._apply_conv(self.q_proj(hidden_states), self.q_conv, conv_state_q, use_cache) k, conv_state_k = self._apply_conv(self.k_proj(hidden_states), self.k_conv, conv_state_k, use_cache) v, conv_state_v = self._apply_conv(self.v_proj(hidden_states), self.v_conv, conv_state_v, use_cache) # Gate kernel computation (raw g, gate applied inside kernel for chunk mode) g = self.f_b_proj(self.f_a_proj(hidden_states)) g = rearrange(g, "... (h d) -> ... h d", d=self.head_dim) # Beta gating beta = self.beta_proj(hidden_states).float().sigmoid() # Reshape Q, K, V to head format q, k = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), (q, k)) v = rearrange(v, "... (h d) -> ... h d", d=self.head_dim) # Run KDA kernel if mode == "chunk": # For chunk mode: gate computed inside kernel (matches FLA reference) o, recurrent_state = chunk_kda( q=q, k=k, v=v, g=g, beta=beta, A_log=self.A_log, dt_bias=self.dt_bias, initial_state=recurrent_state, output_final_state=past_key_values is not None, use_qk_l2norm_in_kernel=True, use_gate_in_kernel=True, ) else: # For fused_recurrent mode: pre-compute gate (matches FLA reference) g = fused_kda_gate(g, self.A_log.float(), dt_bias=self.dt_bias) o, recurrent_state = fused_recurrent_kda( q=q, k=k, v=v, g=g, beta=beta, initial_state=recurrent_state, output_final_state=True, use_qk_l2norm_in_kernel=True, ) # Update cache if past_key_values is not None: past_key_values.recurrent_states[self.layer_idx] = recurrent_state past_key_values.conv_states[self.layer_idx] = (conv_state_q, conv_state_k, conv_state_v) # Output gating and normalization g_out = self.g_b_proj(self.g_a_proj(hidden_states)) g_out = rearrange(g_out, "... (h d) -> ... h d", d=self.head_dim) # Flatten for normalization, then reshape back o_shape = o.shape o = self.norm(o.reshape(-1, o.shape[-1]), g_out.reshape(-1, g_out.shape[-1])) o = o.reshape(o_shape) # Reshape and project output o = rearrange(o, "b t h d -> b t (h d)") o = self.o_proj(o) return (o,) @classmethod def setup( cls, mixer_config: dict, hidden_size: int, max_position_embeddings: int, ) -> nn.ModuleDict: """KimiDeltaAttention has no setup resources - returns empty ModuleDict.""" return nn.ModuleDict() def preprocess( self, hidden_states: torch.Tensor, resources: Optional[nn.ModuleDict], **kwargs: Unpack[BlockSequenceKwargs], ) -> PreprocessingOutput: """KimiDeltaAttention has no preprocessing - returns empty dict.""" return {} class Apriel2BlockSequence(nn.Module): """ Block sequence abstraction - mirrors Fast-LLM's BlockSequence. Used by both text decoder and vision encoder. Architecture: - Pure container for blocks (handles fixed/pattern types) - Delegates resource setup to mixers via mixer.setup() - Owns mixer_resources (ModuleDict from setup, deduplicated by block_name) - Delegates preprocessing to mixers via mixer.preprocess() - Caches preprocessing per unique block type (efficient) - Completely agnostic to mixer types (attention, mamba, etc.) Setup + Preprocessing pattern: 1. Call mixer.setup() for each unique block type → collect resources (rotary_emb, etc.) 2. Call mixer.preprocess() for each unique block type → compute tensors 3. Cache preprocessing results indexed by block_name 4. Reuse cached preprocessing for blocks of same type 5. Merge preprocessing outputs into block kwargs """ def __init__( self, sequence_config: dict, hidden_size: int, max_position_embeddings: int, config: Apriel2TextConfig, ): super().__init__() self.sequence_config = sequence_config self.hidden_size = hidden_size self.max_position_embeddings = max_position_embeddings self.config = config # Build blocks (handles fixed/pattern) # NOTE: _build_blocks() calls classmethod setup() to create mixer_resources BEFORE instances self.blocks = self._build_blocks() # Extract unique mixer instances (one per unique block_name) for preprocessing self.unique_mixers: dict[str, nn.Module] = {} for layer_idx, block in enumerate(self.blocks): block_name = self.get_block_name(layer_idx) if block_name not in self.unique_mixers: self.unique_mixers[block_name] = block.mixer def _build_blocks(self) -> nn.ModuleList: """ Build blocks based on fixed/pattern type. Phase 1: Setup resources (called once per block type, before instances) Phase 2: Create block instances (resources already available) """ seq_type = self.sequence_config.get("type", "fixed") num_blocks = self.sequence_config.get("num_blocks") # PHASE 1: Setup resources BEFORE creating instances # Initialize mixer_resources container self.mixer_resources = nn.ModuleDict() # Extract unique block types and call setup for each if seq_type == "fixed": # Fixed: single block type repeated block_config = self.sequence_config.get("block", {}) mixer_config = block_config.get("mixer", {}) mixer_type = mixer_config.get("type", "attention") # Call classmethod setup mixer_class = get_mixer_class(mixer_type) resources = mixer_class.setup(mixer_config, self.hidden_size, self.max_position_embeddings) if len(resources) > 0: self.mixer_resources["block"] = resources elif seq_type == "pattern": # Pattern: multiple block types in repeating pattern blocks_config = self.sequence_config.get("blocks", {}) for block_name, block_config in blocks_config.items(): mixer_config = block_config.get("mixer", {}) mixer_type = mixer_config.get("type", "attention") # Call classmethod setup mixer_class = get_mixer_class(mixer_type) resources = mixer_class.setup(mixer_config, self.hidden_size, self.max_position_embeddings) if len(resources) > 0: self.mixer_resources[block_name] = resources else: raise ValueError(f"Unknown sequence type: {seq_type}") # PHASE 2: Create block instances (resources already set up) # Extract rms_norm_eps from config head.normalization.epsilon rms_norm_eps = self.config.head["normalization"]["epsilon"] blocks = [] for layer_idx in range(num_blocks): # Get block_config and block_name for this layer if seq_type == "fixed": block_config = self.sequence_config.get("block", {}) block_name_for_layer = None # No adapter needed for fixed type elif seq_type == "pattern": pattern = self.sequence_config.get("pattern", []) blocks_config = self.sequence_config.get("blocks", {}) block_name = pattern[layer_idx % len(pattern)] block_config = blocks_config[block_name] block_name_for_layer = block_name # Pass to Apriel2Block for weight path matching else: raise ValueError(f"Unknown sequence type: {seq_type}") # Create block with explicit parameters (no fake config creation!) blocks.append( Apriel2Block( block_config=block_config, hidden_size=self.hidden_size, layer_idx=layer_idx, rms_norm_eps=rms_norm_eps, config=self.config, block_name=block_name_for_layer, ) ) return nn.ModuleList(blocks) def get_block_name(self, layer_idx: int) -> str: """Get block name for a specific layer (shared logic).""" seq_type = self.sequence_config.get("type", "fixed") if seq_type == "fixed": return "block" elif seq_type == "pattern": pattern = self.sequence_config.get("pattern", []) return pattern[layer_idx % len(pattern)] else: raise ValueError(f"Unknown sequence type: {seq_type}") def preprocess( self, hidden_states: torch.Tensor, **kwargs: Unpack[BlockSequenceKwargs], ) -> dict[str, PreprocessingOutput]: """ Compute preprocessing for all unique block types. Aggregates preprocessing from all unique mixers. Args: hidden_states: Current hidden states (for shape/device) **kwargs: Metadata (position_ids, attention_mask, cache_position, etc.) Returns: Preprocessing cache keyed by block_name """ preprocessing_cache: dict[str, PreprocessingOutput] = {} for block_name, mixer in self.unique_mixers.items(): # Get resources for this block type (from setup) # Note: nn.ModuleDict doesn't have .get(), so we check membership first resources = self.mixer_resources[block_name] if block_name in self.mixer_resources else None # Mixer computes preprocessing using resources (read-only) # Returns PreprocessingOutput (position_embeddings, attention_mask, etc.) preprocessing_cache[block_name] = mixer.preprocess(hidden_states, resources, **kwargs) return preprocessing_cache def forward( self, hidden_states: torch.Tensor, **kwargs: Unpack[BlockSequenceKwargs], ) -> tuple[torch.Tensor, Optional[tuple], Optional[tuple]]: """ Forward pass through block sequence. Args: hidden_states: Input tensor (data) **kwargs: Metadata (attention_mask, position_ids, etc.) Returns: (hidden_states, all_hidden_states, all_attentions) """ # Compute preprocessing ONCE per unique block type # Delegates to self.preprocess() which aggregates from all mixers preprocessing_cache = self.preprocess(hidden_states, **kwargs) # Initialize output collections all_hidden_states = () if kwargs.get("output_hidden_states") else None all_attentions = () if kwargs.get("output_attentions") else None # Iterate through blocks - REUSE cached preprocessing for layer_idx, block in enumerate(self.blocks): # Collect intermediate hidden state if requested if all_hidden_states is not None: all_hidden_states += (hidden_states,) # Get preprocessing for this block type (reused for blocks of same type) block_name = self.get_block_name(layer_idx) preprocessing_kwargs = preprocessing_cache[block_name] # Merge input kwargs with preprocessing outputs # Preprocessing can override (e.g., causal mask overrides attention_mask) block_kwargs = {**kwargs, **preprocessing_kwargs} # Pipe through: y = f(x, **kwargs) # Block extracts what it needs from kwargs layer_outputs = block(hidden_states, **block_kwargs) hidden_states = layer_outputs[0] # Collect attention if requested if all_attentions is not None: all_attentions += (layer_outputs[1] if len(layer_outputs) > 1 else None,) return hidden_states, all_hidden_states, all_attentions class Apriel2Block(nn.Module): """ Transformer block with mixer (attention/mamba/etc) and MLP. Used for both text decoder and vision encoder. """ def __init__( self, block_config: dict, hidden_size: int, layer_idx: int, rms_norm_eps: float, config: Apriel2TextConfig, block_name: Optional[str] = None, ): """ Args: block_config: Dict with 'mixer', 'mlp', 'normalization' configs hidden_size: Model hidden size layer_idx: Layer index in the sequence rms_norm_eps: Epsilon for RMS normalization config: Model config (passed to mixers that need it) block_name: For pattern configs, the mixer name (e.g. "attention") to match supernet weight paths """ super().__init__() self.hidden_size = hidden_size self.layer_idx = layer_idx # Create mixer based on type mixer_config = block_config.get("mixer", {"type": "attention"}) raw_mixer = create_mixer(mixer_config, hidden_size, layer_idx, config, allow_stochastic=True) # For pattern configs, wrap in adapter to match supernet checkpoint weight paths if block_name is not None: self.mixer = Apriel2PatternMixerAdapter(block_name, raw_mixer) else: self.mixer = raw_mixer # Create MLP mlp_config = block_config.get("mlp", {"type": "mlp"}) self.mlp = self._create_mlp(mlp_config, hidden_size) # Create normalization layers norm_config = block_config.get("normalization", {"type": "rms_norm"}) self.input_layernorm = self._create_norm(norm_config, hidden_size, rms_norm_eps) self.post_attention_layernorm = self._create_norm(norm_config, hidden_size, rms_norm_eps) def _create_mlp(self, mlp_config: dict, hidden_size: int): """Create MLP based on config. Supports per-layer bias configuration mirroring Fast-LLM: - add_linear_biases: default bias setting for all layers - layer_1.bias.enabled: override for up_proj/gate_proj - layer_2.bias.enabled: override for down_proj """ mlp_type = mlp_config.get("type", "mlp") if mlp_type == "mlp": intermediate_size = mlp_config["intermediate_size"] activation = mlp_config.get("activation", "silu") gated = mlp_config.get("gated", False) # Per-layer bias configuration (mirrors Fast-LLM structure) default_bias = mlp_config.get("add_linear_biases", False) def get_layer_bias(layer_name: str) -> bool: layer_cfg = mlp_config.get(layer_name, {}) bias_cfg = layer_cfg.get("bias", {}) enabled = bias_cfg.get("enabled") return default_bias if enabled is None else enabled layer_1_bias = get_layer_bias("layer_1") layer_2_bias = get_layer_bias("layer_2") if gated: # MistralMLP uses gate_proj, up_proj, down_proj (all bias controlled together) # For now, we use the default bias setting for gated MLPs # TODO: Add per-layer bias support to gated MLP mlp_cfg = SimpleNamespace( hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=activation, ) return MistralMLP(mlp_cfg) else: return SimpleMLP( hidden_size, intermediate_size, activation, layer_1_bias=layer_1_bias, layer_2_bias=layer_2_bias, ) else: raise ValueError(f"Unknown MLP type: {mlp_type}") def _create_norm(self, norm_config: dict, hidden_size: int, rms_norm_eps: float): """Create normalization layer based on config.""" norm_type = norm_config.get("type", "rms_norm") if norm_type == "rms_norm": return MistralRMSNorm(hidden_size, eps=rms_norm_eps) elif norm_type == "layer_norm": return nn.LayerNorm(hidden_size, eps=rms_norm_eps) else: raise ValueError(f"Unknown normalization type: {norm_type}") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Apriel2Cache] = None, output_attentions: bool = False, use_cache: bool = False, position_embeddings=None, **kwargs, ) -> tuple: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) mixer_outputs = self.mixer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, **kwargs, ) hidden_states = mixer_outputs[0] hidden_states = residual + hidden_states # MLP residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (mixer_outputs[1],) if len(mixer_outputs) > 1 else (None,) if use_cache: outputs += (mixer_outputs[2] if len(mixer_outputs) > 2 else None,) return outputs class Apriel2StochasticMixer(nn.Module): """ Stochastic mixer that contains multiple mixer options. During training: randomly samples one mixer per forward pass During inference: uses the main_mixer """ def __init__(self, mixer_config: dict, config: Apriel2TextConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx # Get sub-mixer configs mixers_config = mixer_config.get("mixers", {}) self.main_mixer_name = mixer_config.get("main_mixer_name", list(mixers_config.keys())[0]) # Sampling strategy self.sampling_strategy = mixer_config.get("sampling_strategy", "uniform") sampling_weights = mixer_config.get("sampling_weights", None) # Create each sub-mixer self.mixers = nn.ModuleDict() for name, sub_mixer_config in mixers_config.items(): self.mixers[name] = create_mixer( sub_mixer_config, config.hidden_size, layer_idx, config, allow_stochastic=False ) # Set up sampling probabilities mixer_names = list(self.mixers.keys()) if self.sampling_strategy == "uniform": self._sampling_probs = [1.0 / len(self.mixers)] * len(self.mixers) elif self.sampling_strategy == "weighted": if sampling_weights is None: raise ValueError("sampling_weights must be provided when using weighted sampling strategy") # Normalize weights to sum to 1.0 total = sum(sampling_weights.get(name, 1.0) for name in mixer_names) self._sampling_probs = [sampling_weights.get(name, 1.0) / total for name in mixer_names] else: raise ValueError(f"Unknown sampling_strategy: {self.sampling_strategy}") self._mixer_names = mixer_names logger.info( f"Initialized Apriel2StochasticMixer at layer {layer_idx} with {len(self.mixers)} mixers: " f"{', '.join(mixer_names)} (main={self.main_mixer_name}, strategy={self.sampling_strategy})" ) def forward( self, hidden_states: torch.Tensor, attention_mask=None, position_embeddings: Optional[dict] = None, **kwargs ): # Sample mixer during training, use main_mixer during inference if self.training: mixer_name = random.choices(self._mixer_names, weights=self._sampling_probs)[0] else: mixer_name = self.main_mixer_name # Set active mixer in cache for proper state routing past_key_values = kwargs.get("past_key_values") if past_key_values is not None and hasattr(past_key_values, "set_active_mixer"): past_key_values.set_active_mixer(self.layer_idx, mixer_name) mixer = self.mixers[mixer_name] mixer_position_embeddings = position_embeddings.get(mixer_name) if position_embeddings else None mixer_attention_mask = attention_mask.get(mixer_name) if isinstance(attention_mask, dict) else attention_mask return mixer( hidden_states, attention_mask=mixer_attention_mask, position_embeddings=mixer_position_embeddings, **kwargs ) @classmethod def setup( cls, mixer_config: dict, hidden_size: int, max_position_embeddings: int, ) -> nn.ModuleDict: """ Setup resources for stochastic mixer with nested mixers. Called before instance creation, recursively calls setup on nested mixer classes. Returns a ModuleDict where each key is a nested mixer name and value is its setup ModuleDict. """ nested_resources = nn.ModuleDict() # Get nested mixers config mixers_config = mixer_config.get("mixers", {}) for mixer_name, sub_mixer_config in mixers_config.items(): # Get mixer class from type mixer_type = sub_mixer_config.get("type", "attention") mixer_class = get_mixer_class(mixer_type) # Call setup on nested mixer class mixer_resources = mixer_class.setup(sub_mixer_config, hidden_size, max_position_embeddings) if len(mixer_resources) > 0: nested_resources[mixer_name] = mixer_resources return nested_resources def preprocess( self, hidden_states: torch.Tensor, resources: Optional[nn.ModuleDict], **kwargs: Unpack[BlockSequenceKwargs], ) -> PreprocessingOutput: """ Preprocess for stochastic mixer with nested mixers. Returns a PreprocessingOutput where position_embeddings and attention_mask are dicts mapping nested mixer names to their respective values. """ nested_position_embeddings = {} nested_attention_masks = {} for mixer_name, nested_mixer in self.mixers.items(): # Get resources for this nested mixer (if resources is a ModuleDict of ModuleDicts) # Note: nn.ModuleDict doesn't have .get(), so we check membership first nested_resources = resources[mixer_name] if resources is not None and mixer_name in resources else None # Get preprocessing for nested mixer nested_output = nested_mixer.preprocess(hidden_states, nested_resources, **kwargs) # Extract position_embeddings (may be None for some mixer types) if nested_output.get("position_embeddings") is not None: nested_position_embeddings[mixer_name] = nested_output["position_embeddings"] # Extract attention_mask (may be None for SDPA, or float for eager) # We include it even if None to override the original long int mask if "attention_mask" in nested_output: nested_attention_masks[mixer_name] = nested_output["attention_mask"] # Return PreprocessingOutput with nested position_embeddings and attention_mask dicts return PreprocessingOutput( position_embeddings=nested_position_embeddings if nested_position_embeddings else None, attention_mask=nested_attention_masks if nested_attention_masks else None, ) class Apriel2PreTrainedModel(PreTrainedModel): config_class = Apriel2TextConfig base_model_prefix = "model" _no_split_modules = ["Apriel2Block"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = False _supports_static_cache = False _supports_attention_backend = True def _prepare_cache_for_generation( self, generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, *args ): if generation_config.use_cache is False: return model_kwargs["past_key_values"] = Apriel2Cache(config=self.config) def _init_weights(self, module): std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, MistralRMSNorm): module.weight.data.fill_(1.0) class Apriel2TextModel(Apriel2PreTrainedModel): """Apriel2 text-only base model (without LM head).""" def __init__(self, config: Apriel2TextConfig): super().__init__(config) self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size # Embeddings self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # Decoder block sequence (uses shared BlockSequence abstraction) # Causal behavior determined by mixer config (attention mixers have causal=True by default) self.decoder = Apriel2BlockSequence( sequence_config=config.decoder, hidden_size=config.hidden_size, max_position_embeddings=config.embeddings["max_position_embeddings"], config=config, ) # Final norm (epsilon from head.normalization config) self.norm = MistralRMSNorm(config.hidden_size, eps=config.head["normalization"]["epsilon"]) self.gradient_checkpointing = False self.post_init() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Apriel2Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = Apriel2Cache(config=self.config) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) # Forward through decoder block sequence (handles position embeddings, masks, and iteration) hidden_states, all_hidden_states, all_self_attns = self.decoder( inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, cache_position=cache_position, **flash_attn_kwargs, ) # Apply final normalization hidden_states = self.norm(hidden_states) # Add final hidden state if requested if output_hidden_states: all_hidden_states += (hidden_states,) next_decoder_cache = past_key_values if use_cache else None if not return_dict: return tuple( v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin): """Apriel2 model with a language modeling head (text-only).""" config_class = Apriel2Config _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: Apriel2TextConfig): super().__init__(config) self.model = Apriel2TextModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing # post_init() calls init_weights() which calls tie_weights() if config.tie_word_embeddings self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Apriel2Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[tuple, CausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Forward through model outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() # Shift for next-token prediction shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class Apriel2Embeddings(nn.Module): """Converts images to patch embeddings via 2D convolution.""" def __init__(self, vision_hidden_size: int, embeddings_config: dict): super().__init__() # Extract parameters from config dict patch_height = embeddings_config.get("patch_height", 16) patch_width = embeddings_config.get("patch_width", 16) input_channels = embeddings_config.get("input_channels", 3) # RGB # 2D convolution to create patch embeddings (internally named patch_embeddings to match Fast-LLM) self.patch_embeddings = nn.Conv2d( in_channels=input_channels, out_channels=vision_hidden_size, kernel_size=(patch_height, patch_width), stride=(patch_height, patch_width), bias=False, ) # Normalization layer norm_config = embeddings_config.get("normalization", {"type": "layer_norm"}) norm_type = norm_config.get("type", "layer_norm") norm_eps = norm_config.get("eps", 1e-5) if norm_type == "layer_norm": self.normalization = nn.LayerNorm(vision_hidden_size, eps=norm_eps) elif norm_type == "rms_norm": self.normalization = MistralRMSNorm(vision_hidden_size, eps=norm_eps) else: raise ValueError(f"Unknown normalization type: {norm_type}") def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Args: pixel_values: [batch, channels, height, width] Returns: patch_embeddings: [batch, num_patches, hidden_size] """ # Apply convolution: [batch, channels, height, width] -> [batch, hidden, num_patches_h, num_patches_w] x = self.patch_embeddings(pixel_values) # Flatten spatial dimensions: [batch, hidden, num_patches_h, num_patches_w] -> [batch, hidden, num_patches] batch_size, hidden_size, h, w = x.shape x = x.view(batch_size, hidden_size, h * w) # Transpose to sequence format: [batch, hidden, num_patches] -> [batch, num_patches, hidden] # NOTE: .contiguous() is required to match Pixtral's numerical behavior. # Pixtral concatenates patches before normalization, which makes the tensor contiguous. # Without this, RMSNorm produces slightly different results (~4.7e-7) due to # floating-point computation order differences on non-contiguous tensors. x = x.transpose(1, 2).contiguous() # Apply normalization x = self.normalization(x) return x def _generate_block_attention_mask( patch_counts: list[int], hidden_states: torch.Tensor, ) -> torch.Tensor: """Generate block diagonal attention mask to isolate images. Like Pixtral's generate_block_attention_mask: each image can only attend to its own patches, preventing cross-image attention. Args: patch_counts: List of patch counts per image [n1, n2, ...] hidden_states: Hidden states tensor for dtype/device [1, total_patches, hidden] Returns: attention_mask: [1, 1, total_patches, total_patches] with 0 for allowed, -inf for blocked """ dtype = hidden_states.dtype device = hidden_states.device seq_len = hidden_states.shape[1] d_min = torch.finfo(dtype).min # Start with all blocked mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) # Unblock each image's diagonal block block_end_idx = torch.tensor(patch_counts, device=device).cumsum(-1) block_start_idx = torch.cat([torch.tensor([0], device=device), block_end_idx[:-1]]) for start, end in zip(block_start_idx, block_end_idx): mask[start:end, start:end] = 0 return mask[None, None, :, :] def _compute_2d_position_ids( patch_embeds_list: list[torch.Tensor], max_patches_per_side: int, patch_size: int, ) -> torch.Tensor: """Compute 2D position IDs for concatenated patches. Like Pixtral's position_ids_in_meshgrid: computes position_id = h * max_width + w for each patch, then concatenates across all images. Args: patch_embeds_list: List of patch embeddings [patches_i, hidden] per image max_patches_per_side: Maximum patches per side for position encoding patch_size: Size of each patch Returns: position_ids: [total_patches] tensor of position IDs """ positions = [] for patch_embed in patch_embeds_list: # Infer grid dimensions from number of patches # This assumes patches are flattened from a grid num_patches = patch_embed.shape[0] # For now, assume square grid or use the stored dimensions # We'll get actual h, w from the caller height = width = int(num_patches**0.5) if height * width != num_patches: # Non-square: will be handled by caller passing dimensions height = width = int(num_patches**0.5) mesh = torch.meshgrid( torch.arange(height, device=patch_embed.device), torch.arange(width, device=patch_embed.device), indexing="ij", ) h_grid, w_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) ids = h_grid * max_patches_per_side + w_grid positions.append(ids[:, 0]) return torch.cat(positions) class Apriel2VisionEncoder(nn.Module): """Vision encoder with embeddings, transformer blocks, and adapter. Uses Pixtral-style processing: concatenates all image patches into one sequence. Computes position_ids for 2D rotary embeddings and sequence_lengths for image isolation - these are passed to encoder blocks. Mixer-specific handling (rotary cos/sin, cu_seqlens) is delegated to each mixer's preprocess() method. """ def __init__(self, vision_encoder_config: dict, text_config: Apriel2Config): super().__init__() self.hidden_size = vision_encoder_config["hidden_size"] # Build embeddings layer embeddings_config = vision_encoder_config["embeddings"] self.embeddings = Apriel2Embeddings(self.hidden_size, embeddings_config) # Store patch size for computing patch grid dimensions self.patch_size = embeddings_config["patch_height"] # Get max_image_size for 2D position encoding (vision encoder owns this) # Priority: encoder-level config > rotary config in any attention block > default self.max_image_size = self._get_max_image_size(vision_encoder_config) self.max_patches_per_side = self.max_image_size // self.patch_size # Build vision transformer encoder using shared BlockSequence abstraction encoder_config = vision_encoder_config.get("encoder", {}) # Get norm epsilon from text config's head.normalization.epsilon norm_epsilon = text_config.head["normalization"]["epsilon"] # Create a minimal config for vision blocks (hierarchical structure) vision_block_config = Apriel2TextConfig( hidden_size=self.hidden_size, embeddings={"max_position_embeddings": 1024}, # Large enough for typical vision use cases head={"normalization": {"type": "rms_norm", "epsilon": norm_epsilon}}, _attn_implementation=getattr(text_config, "_attn_implementation", "eager"), ) # Vision encoder block sequence - supports any mixer type self.encoder = Apriel2BlockSequence( sequence_config=encoder_config, hidden_size=self.hidden_size, max_position_embeddings=1024, config=vision_block_config, ) # Build adapter/projector adapter_config = vision_encoder_config.get("adapter", {}) self.adapter = self._build_adapter(adapter_config, text_config.hidden_size) def _build_adapter(self, adapter_config: dict, text_hidden_size: int) -> nn.Module: """Build adapter/projector from config dict.""" adapter_type = adapter_config.get("type", "mlp") if adapter_type == "mlp": # 2-layer MLP projector (mirrors Fast-LLM's adapter) intermediate_size = adapter_config.get("intermediate_size", text_hidden_size) activation = adapter_config.get("activation", "gelu") return Apriel2MultiModalProjector( vision_hidden_size=self.hidden_size, text_hidden_size=text_hidden_size, intermediate_size=intermediate_size, activation=activation, ) else: raise ValueError(f"Unknown adapter type: {adapter_type}") def _get_max_image_size(self, config: dict) -> int: """Extract max_image_size from config with fallback chain. This is a vision encoder concern - determines 2D position encoding grid size. Priority: 1. Encoder-level config: config["max_image_size"] 2. From any attention block's rotary config (for backward compatibility) 3. Default: 4096 (supports up to ~292x292 patches with patch_size=14) """ # Priority 1: encoder-level config if "max_image_size" in config: return config["max_image_size"] # Priority 2: search through blocks for rotary config encoder_config = config.get("encoder", {}) for block_config in self._iter_block_configs(encoder_config): mixer_config = block_config.get("mixer", {}) rotary_config = mixer_config.get("rotary", {}) if "max_image_size" in rotary_config: return rotary_config["max_image_size"] # Default fallback return 4096 def _iter_block_configs(self, encoder_config: dict): """Iterate over all block configs in encoder (handles fixed/pattern types).""" seq_type = encoder_config.get("type", "fixed") if seq_type == "fixed": block_config = encoder_config.get("block", {}) if block_config: yield block_config elif seq_type == "pattern": blocks_config = encoder_config.get("blocks", {}) for block_config in blocks_config.values(): yield block_config def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """Process images through vision encoder using Pixtral-style concatenation. All image patches are concatenated into ONE sequence. Vision encoder computes: - position_ids: 2D position encoding (row * max_patches_per_side + col) - sequence_lengths: patches per image (for image isolation) These are passed to encoder blocks. Mixer-specific handling (rotary cos/sin, cu_seqlens/masks) is delegated to each mixer's preprocess() method. Args: pixel_values: [batch, channels, height, width] - batch of images Returns: image_features: [batch, num_patches, text_hidden_size] """ batch_size = pixel_values.shape[0] _, _, img_height, img_width = pixel_values.shape height_patches = img_height // self.patch_size width_patches = img_width // self.patch_size num_patches_per_image = height_patches * width_patches # Process each image through embeddings independently, then concatenate # This mirrors Pixtral's approach of processing conv independently patch_embeds_list = [] for i in range(batch_size): # [1, channels, H, W] -> [1, num_patches, hidden] embed = self.embeddings(pixel_values[i : i + 1]) # [num_patches, hidden] patch_embeds_list.append(embed.squeeze(0)) # Concatenate all patches into one sequence: [1, total_patches, hidden] hidden_states = torch.cat(patch_embeds_list, dim=0).unsqueeze(0) # Compute position_ids for 2D rotary: position_id = row * max_patches_per_side + col # Vision encoder owns 2D position encoding - attention just uses position_ids positions = [] for _ in range(batch_size): mesh = torch.meshgrid( torch.arange(height_patches, device=hidden_states.device), torch.arange(width_patches, device=hidden_states.device), indexing="ij", ) h_grid, w_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) ids = h_grid * self.max_patches_per_side + w_grid positions.append(ids[:, 0]) position_ids = torch.cat(positions).unsqueeze(0) # [1, total_patches] # Sequence lengths: patches per image (for image isolation in attention) sequence_lengths = [num_patches_per_image] * batch_size # Forward through vision encoder block sequence hidden_states, _, _ = self.encoder( hidden_states, attention_mask=None, # Attention computes masks from sequence_lengths if needed position_ids=position_ids, sequence_lengths=sequence_lengths, past_key_values=None, output_attentions=False, output_hidden_states=False, use_cache=False, cache_position=None, ) # Adapter/projector: [1, total_patches, vision_hidden] -> [1, total_patches, text_hidden] image_features = self.adapter(hidden_states) # Reshape back to [batch, num_patches, text_hidden] image_features = image_features.squeeze(0).view(batch_size, num_patches_per_image, -1) return image_features class SimpleMLP(nn.Module): """Non-gated MLP: up_proj -> activation -> down_proj. Supports per-layer bias configuration mirroring Fast-LLM: - layer_1_bias: bias for up_proj (layer_1 in Fast-LLM naming) - layer_2_bias: bias for down_proj (layer_2 in Fast-LLM naming) """ def __init__( self, hidden_size: int, intermediate_size: int, activation: str = "silu", layer_1_bias: bool = False, layer_2_bias: bool = False, ): super().__init__() from transformers.activations import ACT2FN self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=layer_1_bias) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=layer_2_bias) self.act_fn = ACT2FN[activation] def forward(self, x): return self.down_proj(self.act_fn(self.up_proj(x))) class Apriel2MultiModalProjector(nn.Module): """Projects vision features to text embedding space (2-layer MLP).""" def __init__( self, vision_hidden_size: int, text_hidden_size: int, intermediate_size: Optional[int] = None, activation: str = "gelu", ): super().__init__() from transformers.activations import ACT2FN if intermediate_size is None: intermediate_size = text_hidden_size self.linear_1 = nn.Linear(vision_hidden_size, intermediate_size, bias=True) self.act = ACT2FN[activation] self.linear_2 = nn.Linear(intermediate_size, text_hidden_size, bias=True) def forward(self, image_features): hidden_states = self.linear_1(image_features) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states class Apriel2Model(Apriel2TextModel): """ Apriel2 multimodal base model (vision + text, without LM head). Inherits from Apriel2TextModel (which provides embed_tokens, decoder, norm) and adds vision_encoder. This mirrors Fast-LLM's VisionMultiModalModel(LanguageModel) inheritance pattern for trivial weight conversion. """ config_class = Apriel2Config def __init__(self, config: Apriel2Config): super().__init__(config) # Add vision encoder (text components inherited from Apriel2TextModel) if config.vision_encoder is not None: self.vision_encoder = Apriel2VisionEncoder(config.vision_encoder, config) else: self.vision_encoder = None # Re-run post_init to handle any vision encoder initialization self.post_init() def get_image_features(self, pixel_values, image_sizes=None): """Extract and project image features. Args: pixel_values: [num_images, channels, height, width] - batch of images (possibly padded) image_sizes: Optional[num_images, 2] - actual (height, width) of each image for cropping Returns: image_features: [num_images, num_patches, hidden_size] or concatenated features """ if self.vision_encoder is None: raise ValueError("Cannot extract image features: vision_encoder is None") if image_sizes is None: # No cropping needed - process as batch return self.vision_encoder(pixel_values) # Get patch size from embeddings layer to determine minimum valid image size patch_height = self.vision_encoder.embeddings.patch_embeddings.kernel_size[0] patch_width = self.vision_encoder.embeddings.patch_embeddings.kernel_size[1] # Process each image individually with its actual size all_features = [] for i, (image, (height, width)) in enumerate(zip(pixel_values, image_sizes)): height, width = int(height), int(width) # Skip images that are too small to produce any patches if height < patch_height or width < patch_width: continue # Crop to actual image size cropped = image[:, :height, :width] # Process single image - add batch dim features = self.vision_encoder(cropped.unsqueeze(0)) # Remove batch dim and add to list all_features.append(features.squeeze(0)) if not all_features: # No valid images - return empty tensor return torch.zeros(0, 0, self.config.hidden_size, device=pixel_values.device) # Concatenate all features along patch dimension return torch.cat(all_features, dim=0).unsqueeze(0) # [1, total_patches, hidden] def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Apriel2Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[tuple, BaseModelOutputWithPast]: # If pixel_values provided, we need to merge vision and text embeddings if pixel_values is not None and input_ids is not None: # Encode and project images (with optional cropping based on image_sizes) image_features = self.get_image_features(pixel_values, image_sizes) # Get text embeddings (use inherited embed_tokens) inputs_embeds = self.embed_tokens(input_ids) # Merge image features into text embeddings image_token_index = self.config.image_token_index # Create mask for image token positions: [batch, seq_len] special_image_mask = input_ids == image_token_index # Validate token count matches feature count num_image_tokens = special_image_mask.sum().item() num_image_features = image_features.shape[0] * image_features.shape[1] if num_image_tokens != num_image_features: raise ValueError( f"Image features and image tokens do not match: " f"got {num_image_tokens} image tokens but {num_image_features} image features " f"(shape: {image_features.shape})" ) # Expand mask to match embedding dimension: [batch, seq_len, hidden_size] special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds) # Flatten image features to match the number of True values in mask image_features = image_features.view(-1, image_features.shape[-1]) # Use masked_scatter for efficient vectorized merge inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # Clear input_ids since we're using inputs_embeds input_ids = None # Forward through inherited text model components return super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) class Apriel2ForConditionalGeneration(Apriel2PreTrainedModel, GenerationMixin): """ Apriel2 multimodal model with language modeling head (vision + text). Inherits from Apriel2PreTrainedModel to get proper cache handling. Uses Apriel2Model (which inherits from Apriel2TextModel) for the base model. """ config_class = Apriel2Config _tied_weights_keys = [] # No weight tying by default, but can be configured def __init__(self, config: Apriel2Config): super().__init__(config) self.model = Apriel2Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Handle weight tying if configured if config.tie_word_embeddings: self._tied_weights_keys = ["lm_head.weight"] self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def get_image_features(self, pixel_values): """Extract and project image features.""" return self.model.get_image_features(pixel_values) def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Apriel2Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[tuple, CausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Forward through model outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state if return_dict else outputs[0] # Compute logits slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] if attention_mask is not None: # Use the input attention mask to shift the logits and labels # Crop attention mask in case it is longer (e.g., in PrefixTuning with peft) shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) shift_logits = shift_logits[shift_attention_mask != 0].contiguous() shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() else: shift_logits = shift_logits.contiguous() shift_labels = shift_labels.contiguous() # Flatten the tokens loss_fct = nn.CrossEntropyLoss() flat_logits = shift_logits.view(-1, self.vocab_size) flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) if not return_dict: output = (logits,) + (outputs[1:] if return_dict else outputs[1:]) return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values if return_dict else outputs[1], hidden_states=outputs.hidden_states if return_dict else None, attentions=outputs.attentions if return_dict else None, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, cache_position=None, position_ids=None, pixel_values=None, attention_mask=None, use_cache=True, logits_to_keep=None, **kwargs, ): """Prepare inputs for generation, handling multimodal inputs correctly.""" # Overwritten -- custom handling for pixel_values during cached generation model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, cache_position=cache_position, use_cache=use_cache, logits_to_keep=logits_to_keep, **kwargs, ) # If we're in cached decoding stage, pixel_values should be None because input ids do not contain # special image tokens anymore. Otherwise pixel_values should be passed to model. # NOTE: use_cache=False always needs pixel_values if cache_position is not None and cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values return model_inputs