Make it compatible with transformers 5.3.0
#91
by Kaixuanliu - opened
- modeling_phi4mm.py +21 -10
- speech_conformer_encoder.py +3 -2
modeling_phi4mm.py
CHANGED
|
@@ -26,7 +26,7 @@ from torch import nn
|
|
| 26 |
from torch.nn import CrossEntropyLoss
|
| 27 |
|
| 28 |
from transformers.activations import ACT2FN
|
| 29 |
-
from transformers.cache_utils import Cache, DynamicCache,
|
| 30 |
from transformers.generation import GenerationMixin
|
| 31 |
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 32 |
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
|
@@ -41,7 +41,7 @@ from transformers.utils import (
|
|
| 41 |
add_code_sample_docstrings,
|
| 42 |
add_start_docstrings,
|
| 43 |
add_start_docstrings_to_model_forward,
|
| 44 |
-
|
| 45 |
logging,
|
| 46 |
replace_return_docstrings,
|
| 47 |
)
|
|
@@ -1134,7 +1134,7 @@ class Phi4MMAttention(nn.Module):
|
|
| 1134 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 1135 |
"with a layer index."
|
| 1136 |
)
|
| 1137 |
-
kv_seq_len += past_key_value.
|
| 1138 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 1139 |
|
| 1140 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
@@ -1190,7 +1190,7 @@ class Phi4MMFlashAttention2(Phi4MMAttention):
|
|
| 1190 |
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 1191 |
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 1192 |
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 1193 |
-
self._flash_attn_uses_top_left_mask = not
|
| 1194 |
|
| 1195 |
def forward(
|
| 1196 |
self,
|
|
@@ -1229,7 +1229,7 @@ class Phi4MMFlashAttention2(Phi4MMAttention):
|
|
| 1229 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 1230 |
"with a layer index."
|
| 1231 |
)
|
| 1232 |
-
kv_seq_len += past_key_value.
|
| 1233 |
|
| 1234 |
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 1235 |
rotary_seq_len = (
|
|
@@ -1351,7 +1351,7 @@ class Phi4MMSdpaAttention(Phi4MMAttention):
|
|
| 1351 |
|
| 1352 |
kv_seq_len = key_states.shape[-2]
|
| 1353 |
if past_key_value is not None:
|
| 1354 |
-
kv_seq_len += past_key_value.
|
| 1355 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 1356 |
|
| 1357 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
@@ -1399,6 +1399,7 @@ class Phi4MMSdpaAttention(Phi4MMAttention):
|
|
| 1399 |
PHI4MM_ATTENTION_CLASSES = {
|
| 1400 |
"eager": Phi4MMAttention,
|
| 1401 |
"flash_attention_2": Phi4MMFlashAttention2,
|
|
|
|
| 1402 |
"sdpa": Phi4MMSdpaAttention,
|
| 1403 |
}
|
| 1404 |
|
|
@@ -1511,6 +1512,7 @@ class Phi4MMPreTrainedModel(PreTrainedModel):
|
|
| 1511 |
supports_gradient_checkpointing = True
|
| 1512 |
_no_split_modules = ["Phi4MMDecoderLayer"]
|
| 1513 |
_skip_keys_device_placement = "past_key_values"
|
|
|
|
| 1514 |
_supports_flash_attn_2 = True
|
| 1515 |
_supports_sdpa = True
|
| 1516 |
_supports_cache_class = True
|
|
@@ -1807,7 +1809,7 @@ class Phi4MMModel(Phi4MMPreTrainedModel):
|
|
| 1807 |
# to infer the attention mask.
|
| 1808 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1809 |
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 1810 |
-
using_sliding_window_cache =
|
| 1811 |
|
| 1812 |
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 1813 |
if (
|
|
@@ -1913,7 +1915,7 @@ class Phi4MMModel(Phi4MMPreTrainedModel):
|
|
| 1913 |
if config.sliding_window is not None:
|
| 1914 |
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
| 1915 |
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
| 1916 |
-
if
|
| 1917 |
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
| 1918 |
cache_position.reshape(-1, 1) - config.sliding_window
|
| 1919 |
)
|
|
@@ -1934,7 +1936,7 @@ class Phi4MMModel(Phi4MMPreTrainedModel):
|
|
| 1934 |
|
| 1935 |
|
| 1936 |
class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
|
| 1937 |
-
_tied_weights_keys =
|
| 1938 |
|
| 1939 |
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi
|
| 1940 |
def __init__(self, config):
|
|
@@ -1949,6 +1951,12 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
|
|
| 1949 |
# LoRA related settings
|
| 1950 |
assert getattr(config, "vision_lora", None) is not None
|
| 1951 |
from peft import LoraConfig, get_peft_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1952 |
vision_lora_config = LoraConfig(
|
| 1953 |
r=config.vision_lora['r'],
|
| 1954 |
lora_alpha=config.vision_lora['lora_alpha'],
|
|
@@ -2134,7 +2142,10 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
|
|
| 2134 |
|
| 2135 |
hidden_states = outputs[0]
|
| 2136 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 2137 |
-
|
|
|
|
|
|
|
|
|
|
| 2138 |
|
| 2139 |
loss = None
|
| 2140 |
if labels is not None:
|
|
|
|
| 26 |
from torch.nn import CrossEntropyLoss
|
| 27 |
|
| 28 |
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
| 30 |
from transformers.generation import GenerationMixin
|
| 31 |
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 32 |
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
|
|
|
| 41 |
add_code_sample_docstrings,
|
| 42 |
add_start_docstrings,
|
| 43 |
add_start_docstrings_to_model_forward,
|
| 44 |
+
is_flash_attn_greater_or_equal,
|
| 45 |
logging,
|
| 46 |
replace_return_docstrings,
|
| 47 |
)
|
|
|
|
| 1134 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 1135 |
"with a layer index."
|
| 1136 |
)
|
| 1137 |
+
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
| 1138 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 1139 |
|
| 1140 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
|
| 1190 |
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 1191 |
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 1192 |
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 1193 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
| 1194 |
|
| 1195 |
def forward(
|
| 1196 |
self,
|
|
|
|
| 1229 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 1230 |
"with a layer index."
|
| 1231 |
)
|
| 1232 |
+
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
| 1233 |
|
| 1234 |
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 1235 |
rotary_seq_len = (
|
|
|
|
| 1351 |
|
| 1352 |
kv_seq_len = key_states.shape[-2]
|
| 1353 |
if past_key_value is not None:
|
| 1354 |
+
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
| 1355 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 1356 |
|
| 1357 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
|
| 1399 |
PHI4MM_ATTENTION_CLASSES = {
|
| 1400 |
"eager": Phi4MMAttention,
|
| 1401 |
"flash_attention_2": Phi4MMFlashAttention2,
|
| 1402 |
+
"kernels-community/flash-attn2": Phi4MMFlashAttention2,
|
| 1403 |
"sdpa": Phi4MMSdpaAttention,
|
| 1404 |
}
|
| 1405 |
|
|
|
|
| 1512 |
supports_gradient_checkpointing = True
|
| 1513 |
_no_split_modules = ["Phi4MMDecoderLayer"]
|
| 1514 |
_skip_keys_device_placement = "past_key_values"
|
| 1515 |
+
_supports_flash_attn = True
|
| 1516 |
_supports_flash_attn_2 = True
|
| 1517 |
_supports_sdpa = True
|
| 1518 |
_supports_cache_class = True
|
|
|
|
| 1809 |
# to infer the attention mask.
|
| 1810 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1811 |
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 1812 |
+
using_sliding_window_cache = False # SlidingWindowCache removed in newer transformers
|
| 1813 |
|
| 1814 |
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 1815 |
if (
|
|
|
|
| 1915 |
if config.sliding_window is not None:
|
| 1916 |
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
| 1917 |
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
| 1918 |
+
if not using_sliding_window_cache or sequence_length > target_length:  # SlidingWindowCache removed; using_sliding_window_cache is always False
|
| 1919 |
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
| 1920 |
cache_position.reshape(-1, 1) - config.sliding_window
|
| 1921 |
)
|
|
|
|
| 1936 |
|
| 1937 |
|
| 1938 |
class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
|
| 1939 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 1940 |
|
| 1941 |
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi
|
| 1942 |
def __init__(self, config):
|
|
|
|
| 1951 |
# LoRA related settings
|
| 1952 |
assert getattr(config, "vision_lora", None) is not None
|
| 1953 |
from peft import LoraConfig, get_peft_model
|
| 1954 |
+
|
| 1955 |
+
# Add a placeholder prepare_inputs_for_generation to satisfy PEFT's requirements
|
| 1956 |
+
# The actual method is defined on Phi4MMForCausalLM
|
| 1957 |
+
if not hasattr(self.model, 'prepare_inputs_for_generation'):
|
| 1958 |
+
self.model.prepare_inputs_for_generation = lambda *args, **kwargs: None
|
| 1959 |
+
|
| 1960 |
vision_lora_config = LoraConfig(
|
| 1961 |
r=config.vision_lora['r'],
|
| 1962 |
lora_alpha=config.vision_lora['lora_alpha'],
|
|
|
|
| 2142 |
|
| 2143 |
hidden_states = outputs[0]
|
| 2144 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 2145 |
+
if num_logits_to_keep is None or num_logits_to_keep == 0:
|
| 2146 |
+
logits = self.lm_head(hidden_states)
|
| 2147 |
+
else:
|
| 2148 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
| 2149 |
|
| 2150 |
loss = None
|
| 2151 |
if labels is not None:
|
speech_conformer_encoder.py
CHANGED
|
@@ -1423,7 +1423,8 @@ class NemoConvSubsampling(torch.nn.Module):
|
|
| 1423 |
raise ValueError(f"Not valid sub-sampling: {subsampling}!")
|
| 1424 |
|
| 1425 |
if subsampling in ["dw_striding", "striding"]:
|
| 1426 |
-
|
|
|
|
| 1427 |
out_length = calc_length(
|
| 1428 |
lengths=in_length,
|
| 1429 |
all_paddings=self._left_padding + self._right_padding,
|
|
@@ -1432,7 +1433,7 @@ class NemoConvSubsampling(torch.nn.Module):
|
|
| 1432 |
ceil_mode=self._ceil_mode,
|
| 1433 |
repeat_num=self._sampling_num,
|
| 1434 |
)
|
| 1435 |
-
self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
|
| 1436 |
self.conv2d_subsampling = True
|
| 1437 |
elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
|
| 1438 |
self.out = None
|
|
|
|
| 1423 |
raise ValueError(f"Not valid sub-sampling: {subsampling}!")
|
| 1424 |
|
| 1425 |
if subsampling in ["dw_striding", "striding"]:
|
| 1426 |
+
# Force CPU tensor to avoid meta tensor issues with device_map
|
| 1427 |
+
in_length = torch.tensor(feat_in, dtype=torch.float, device='cpu')
|
| 1428 |
out_length = calc_length(
|
| 1429 |
lengths=in_length,
|
| 1430 |
all_paddings=self._left_padding + self._right_padding,
|
|
|
|
| 1433 |
ceil_mode=self._ceil_mode,
|
| 1434 |
repeat_num=self._sampling_num,
|
| 1435 |
)
|
| 1436 |
+
self.out = torch.nn.Linear(conv_channels * int(out_length.item()), feat_out)
|
| 1437 |
self.conv2d_subsampling = True
|
| 1438 |
elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
|
| 1439 |
self.out = None
|