| from typing import Callable, Optional, Tuple |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
|
|
| from transformers.cache_utils import Cache |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| from transformers.processing_utils import Unpack |
| from transformers.utils import logging |
| from transformers.models.llama.modeling_llama import ( |
| LlamaAttention, |
| LlamaDecoderLayer, |
| LlamaForCausalLM, |
| LlamaForQuestionAnswering, |
| LlamaForSequenceClassification, |
| LlamaForTokenClassification, |
| LlamaMLP, |
| LlamaModel, |
| apply_rotary_pos_emb, |
| eager_attention_forward, |
| ) |
| from .configuration_qwen2 import Qwen2Config |
|
|
|
|
# Module-level logger; this file uses it for one-time (`warning_once`) warnings about
# attention-backend limitations.
logger = logging.get_logger(__name__)
|
|
|
|
class Qwen2MLP(LlamaMLP):
    """Gated feed-forward block for Qwen2.

    Defines no ``forward`` of its own, so the computation is inherited from
    ``LlamaMLP``; only the three projections are re-created, all without bias.
    """

    def __init__(self, config):
        super().__init__(config)
        # Sizes come from attributes presumably populated by LlamaMLP.__init__.
        hidden, inner = self.hidden_size, self.intermediate_size
        # Expansion projections: hidden -> intermediate.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        # Contraction projection: intermediate -> hidden.
        self.down_proj = nn.Linear(inner, hidden, bias=False)
|
|
|
|
class Qwen2Attention(LlamaAttention):
    """Multi-headed self-attention for Qwen2.

    Differs from ``LlamaAttention`` in two visible ways:

    * the query/key/value projections carry a bias, while the output projection
      does not;
    * when ``config.use_sliding_window`` is enabled and ``config.sliding_window``
      is set, layers with ``layer_idx >= config.max_window_layers`` pass that
      window size to the attention backend.
    """

    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        # Re-create the projections: Q/K/V are biased, the output projection is not.
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input activations. Every dimension except the last is
                kept as the leading shape of the output; presumably
                ``(batch, seq_len, hidden_size)`` — TODO confirm.
            position_embeddings: ``(cos, sin)`` rotary tables, applied to queries
                and keys via ``apply_rotary_pos_emb``.
            attention_mask: Forwarded unchanged to the attention backend.
            past_key_value: Optional cache. When given, ``past_key_value.update``
                receives the new key/value states and its returned states are the
                ones attended over.
            cache_position: Passed to the cache update in ``cache_kwargs``.
            **kwargs: Extra backend arguments (e.g. ``output_attentions``),
                forwarded to the attention function.

        Returns:
            A 2-tuple ``(attn_output, attn_weights)``: the output-projected
            attention result and whatever weights the selected attention function
            returns (presumably ``None`` for fused backends — confirm per backend).
        """
        input_shape = hidden_states.shape[:-1]
        # Split the feature dimension into (num_heads, head_dim); the head count is inferred.
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads, then swap dims 1 and 2 so the head axis precedes the
        # sequence axis (assumes a 3-D input, giving (batch, heads, seq_len, head_dim)).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Rotary position embeddings are applied to queries and keys only.
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # The cache receives the rotary tables and cache positions alongside the new
            # states, and returns the key/value states to attend over.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Sliding-window attention is opt-in and only requested from `max_window_layers`
        # onwards; other layers pass sliding_window=None to the backend.
        sliding_window = None
        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                # SDPA cannot return attention weights, so keep the eager implementation.
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=sliding_window,
            **kwargs,
        )

        # Merge the heads back into one feature dimension (assumes the backend returns
        # heads adjacent to head_dim), then apply the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
class Qwen2DecoderLayer(LlamaDecoderLayer):
    """Decoder layer for Qwen2.

    Inherits the layer structure and ``forward`` from ``LlamaDecoderLayer`` and
    replaces its attention and MLP sub-modules with ``Qwen2Attention`` and
    ``Qwen2MLP``.
    """

    def __init__(self, config: Qwen2Config, layer_idx: int):
        # LlamaDecoderLayer.__init__ takes (config, layer_idx); calling it with no
        # arguments fails with a TypeError when this class is instantiated directly.
        super().__init__(config, layer_idx)
        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen2MLP(config)

        # Use the same condition Qwen2Attention uses to request a sliding window, so the
        # warning fires only when this layer will actually ask for one. Checking
        # `config.sliding_window` alone warned even when `use_sliding_window` was False.
        uses_sliding_window = (
            config.use_sliding_window
            and getattr(config, "sliding_window", None) is not None
            and layer_idx >= config.max_window_layers
        )
        if uses_sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
|
|
|
|
class Qwen2Model(LlamaModel):
    # Qwen2 base model: defines no overrides, so all behavior is inherited from LlamaModel.
    # NOTE(review): as plain Python this builds Llama sub-modules; presumably the modular
    # code generator rewrites those references to the Qwen2 classes — confirm.
    pass
|
|
|
|
class Qwen2ForCausalLM(LlamaForCausalLM):
    # Causal-LM head for Qwen2: no overrides, everything is inherited from LlamaForCausalLM.
    pass
|
|
|
|
class Qwen2ForSequenceClassification(LlamaForSequenceClassification):
    # Sequence-classification head for Qwen2: no overrides, inherited from
    # LlamaForSequenceClassification.
    pass
|
|
|
|
class Qwen2ForTokenClassification(LlamaForTokenClassification):
    # Token-classification head for Qwen2: no overrides, inherited from
    # LlamaForTokenClassification.
    pass
|
|
|
|
class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
    # Question-answering head for Qwen2: no overrides, inherited from LlamaForQuestionAnswering.
    pass
|
|