File size: 8,345 Bytes
bdbf0fa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 | from typing import Optional, Callable
from typing_extensions import Unpack, Tuple
import torch
from torch import nn
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
Qwen3Config,
Qwen3PreTrainedModel,
Qwen3MLP,
GradientCheckpointingLayer,
FlashAttentionKwargs,
rotate_half,
eager_attention_forward,
ALL_ATTENTION_FUNCTIONS,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings where queries cover only the tail of the key span.

    In this model the keys are ``[context ; noise]`` while the queries are only
    the noise part, so ``cos``/``sin`` are expected to cover the full key length.
    Queries use the last ``q.size(-2)`` positions of ``cos``/``sin``; keys use
    all of them.

    Args:
        q: Query tensor. As built by Qwen3DFlashAttention.forward it is
            (batch, heads, q_len, head_dim).
        k: Key tensor; as built by Qwen3DFlashAttention.forward it is
            (batch, kv_heads, ctx_len + q_len, head_dim).
        cos, sin: Rotary tables. ``unsqueeze_dim`` is inserted so they broadcast
            over the head axis. Presumably (batch, seq, head_dim) from
            Qwen3RotaryEmbedding — TODO confirm.
        position_ids: Unused; accepted only for signature compatibility.
        unsqueeze_dim: Axis at which to insert the broadcast dimension.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    # Keys span every position the tables describe.
    k_embed = k * cos_b + rotate_half(k) * sin_b

    # Queries correspond to the trailing block of positions only.
    tail = q.size(-2)
    cos_tail = cos_b[..., -tail:, :]
    sin_tail = sin_b[..., -tail:, :]
    q_embed = q * cos_tail + rotate_half(q) * sin_tail

    return q_embed, k_embed
class Qwen3DFlashAttention(nn.Module):
    """Multi-head attention in which draft (noise) queries attend over context + noise.

    Queries are projected from ``hidden_states`` only. Keys and values are
    projected from ``target_hidden`` and from ``hidden_states`` and concatenated
    along the sequence axis (context first). ``is_causal`` is False, so any
    masking comes solely from the ``attention_mask`` the caller supplies.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to hidden_size / num_attention_heads when the config has no head_dim.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias
        # NOTE: keep this creation order (q, k, v, o, q_norm, k_norm) so random
        # initialisation order and state_dict key order are unchanged.
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=use_bias)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        is_sliding_layer = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if is_sliding_layer else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        target_hidden: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend from ``hidden_states`` over ``[target_hidden ; hidden_states]``.

        Args:
            hidden_states: (batch, q_len, hidden_size) draft states; source of queries.
            target_hidden: (batch, ctx_len, hidden_size) context states; extra keys/values.
            position_embeddings: ``(cos, sin)``; must cover ctx_len + q_len positions,
                since apply_rotary_pos_emb rotates queries with the trailing q_len of them.
            attention_mask: Passed unchanged to the attention function.
            past_key_values: Optional cache; when given, the freshly built keys/values
                (context + noise) are pushed through ``update`` for this layer.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention function.

        Returns:
            ``(output, weights)``: output is (batch, q_len, hidden_size); weights are
            whatever the selected attention function returns.
        """
        batch, n_noise = hidden_states.shape[:-1]
        n_ctx = target_hidden.shape[1]
        n_total = n_ctx + n_noise

        query = self.q_proj(hidden_states).view(batch, n_noise, -1, self.head_dim)
        query = self.q_norm(query).transpose(1, 2)

        # Keys/values: context first, then noise, along the sequence axis.
        key_parts = [self.k_proj(target_hidden), self.k_proj(hidden_states)]
        value_parts = [self.v_proj(target_hidden), self.v_proj(hidden_states)]
        key = torch.cat(key_parts, dim=1).view(batch, n_total, -1, self.head_dim)
        value = torch.cat(value_parts, dim=1).view(batch, n_total, -1, self.head_dim)
        key = self.k_norm(key).transpose(1, 2)
        value = value.transpose(1, 2)

        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        if past_key_values is not None:
            update_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key, value = past_key_values.update(key, value, self.layer_idx, update_kwargs)

        impl_name = self.config._attn_implementation
        if impl_name == "eager":
            attention_interface: Callable = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[impl_name]

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        attn_output = self.o_proj(attn_output.reshape(batch, n_noise, -1))
        return attn_output, attn_weights
class Qwen3DFlashDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block of the DFlash draft model.

    Applies Qwen3DFlashAttention (queries from ``hidden_states``, keys/values from
    ``target_hidden`` followed by ``hidden_states``), then a Qwen3 MLP. Each
    sub-block is wrapped in RMSNorm and a residual connection. ``target_hidden``
    is read but not updated, and it is not passed through this layer's
    ``input_layernorm``.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        target_hidden: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # required; Optional only for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Run attention + MLP on ``hidden_states`` and return the updated states.

        Args:
            target_hidden: Context states that supply extra keys/values to attention.
            hidden_states: Draft states being refined; source of the queries.
            attention_mask: Forwarded to the attention module.
            position_ids, output_attentions, use_cache: Not used here directly. They
                are forwarded through the attention module's ``**kwargs`` to the
                attention function.
            past_key_value: Cache forwarded to attention as ``past_key_values``.
            cache_position: Forwarded to attention for the cache update.
            position_embeddings: ``(cos, sin)`` rotary tables; required.
            **kwargs: Forwarded to the attention module.

        Returns:
            The updated ``hidden_states`` as a single tensor. Attention weights are
            discarded.

        Raises:
            ValueError: If ``hidden_states``, ``target_hidden`` or
                ``position_embeddings`` is None.
        """
        # These are keyword-optional only for call-site compatibility; fail clearly
        # here instead of with an opaque error deep inside attention.
        if hidden_states is None or target_hidden is None or position_embeddings is None:
            raise ValueError(
                "Qwen3DFlashDecoderLayer requires `hidden_states`, `target_hidden` "
                "and `position_embeddings`."
            )

        # Attention sub-block: pre-norm + residual.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            target_hidden=target_hidden,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )[0]  # keep the output, drop the attention weights
        hidden_states = residual + hidden_states

        # MLP sub-block: pre-norm + residual.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + hidden_states
class DFlashDraftModel(Qwen3PreTrainedModel):
    """DFlash draft model: Qwen3DFlashDecoderLayer stack over projected target states.

    Hidden states from several target-model layers are concatenated on the
    feature axis by the caller. They are projected to ``hidden_size`` by ``fc``,
    normalised by ``hidden_norm``, and used by every layer as extra attention
    context for the ``noise_embedding`` block.
    """

    config_class = Qwen3Config
    _no_split_modules = ["Qwen3DFlashDecoderLayer"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config
        self.layers = nn.ModuleList(
            [Qwen3DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.target_layer_ids = self.config.dflash_config.get("target_layer_ids", None)
        if self.target_layer_ids is None:
            # Previously this surfaced as an opaque `len(None)` TypeError below.
            raise ValueError("config.dflash_config must define 'target_layer_ids'.")
        # Module creation order below is kept as-is so initialisation order and
        # state_dict key order are unchanged.
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config)
        # Projects the concatenated target-layer features back to hidden_size.
        self.fc = nn.Linear(len(self.target_layer_ids) * config.hidden_size, config.hidden_size, bias=False)
        self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.block_size = config.block_size
        self.mask_token_id = self.config.dflash_config.get("mask_token_id", None)
        self.post_init()

    def forward(
        self,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        noise_embedding: Optional[torch.Tensor] = None,
        target_hidden: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """Refine ``noise_embedding`` conditioned on target-model hidden states.

        Args:
            position_ids: Passed to ``rotary_emb`` and to each layer. Because
                attention rotates queries with the trailing positions only, these
                presumably cover context + noise positions — TODO confirm with caller.
            attention_mask: Forwarded to every layer.
            noise_embedding: Initial draft states (last dim ``hidden_size``); required.
            target_hidden: Concatenated target-layer states whose last dim is
                ``len(target_layer_ids) * hidden_size`` (the ``fc`` input width); required.
            past_key_values: Cache forwarded to each layer.
            use_cache: Forwarded to each layer.
            **kwargs: Forwarded to each layer.

        Returns:
            The final ``norm``-ed hidden states (a tensor, not a model-output object).

        Raises:
            ValueError: If ``noise_embedding`` or ``target_hidden`` is None.
        """
        if noise_embedding is None or target_hidden is None:
            raise ValueError("DFlashDraftModel.forward requires `noise_embedding` and `target_hidden`.")

        hidden_states = noise_embedding
        target_hidden = self.hidden_norm(self.fc(target_hidden))
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers:
            # Pass the activation tensors positionally, in the layer's signature order
            # (target_hidden, hidden_states). transformers' GradientCheckpointingLayer
            # routes only positional args through the checkpoint function and binds
            # keyword args via functools.partial. Keyword-only tensors would therefore
            # bypass checkpointing when gradient checkpointing is enabled.
            hidden_states = layer(
                target_hidden,
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        return self.norm(hidden_states)