Make it compatible with transformers 5.3.0

#91
Files changed (2) hide show
  1. modeling_phi4mm.py +21 -10
  2. speech_conformer_encoder.py +3 -2
modeling_phi4mm.py CHANGED
@@ -26,7 +26,7 @@ from torch import nn
26
  from torch.nn import CrossEntropyLoss
27
 
28
  from transformers.activations import ACT2FN
29
- from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
30
  from transformers.generation import GenerationMixin
31
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
32
  from transformers.modeling_flash_attention_utils import _flash_attention_forward
@@ -41,7 +41,7 @@ from transformers.utils import (
41
  add_code_sample_docstrings,
42
  add_start_docstrings,
43
  add_start_docstrings_to_model_forward,
44
- is_flash_attn_greater_or_equal_2_10,
45
  logging,
46
  replace_return_docstrings,
47
  )
@@ -1134,7 +1134,7 @@ class Phi4MMAttention(nn.Module):
1134
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
1135
  "with a layer index."
1136
  )
1137
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1138
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1139
 
1140
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -1190,7 +1190,7 @@ class Phi4MMFlashAttention2(Phi4MMAttention):
1190
  # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
1191
  # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
1192
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
1193
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
1194
 
1195
  def forward(
1196
  self,
@@ -1229,7 +1229,7 @@ class Phi4MMFlashAttention2(Phi4MMAttention):
1229
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
1230
  "with a layer index."
1231
  )
1232
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1233
 
1234
  # Because the input can be padded, the absolute sequence length depends on the max position id.
1235
  rotary_seq_len = (
@@ -1351,7 +1351,7 @@ class Phi4MMSdpaAttention(Phi4MMAttention):
1351
 
1352
  kv_seq_len = key_states.shape[-2]
1353
  if past_key_value is not None:
1354
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1355
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1356
 
1357
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -1399,6 +1399,7 @@ class Phi4MMSdpaAttention(Phi4MMAttention):
1399
  PHI4MM_ATTENTION_CLASSES = {
1400
  "eager": Phi4MMAttention,
1401
  "flash_attention_2": Phi4MMFlashAttention2,
 
1402
  "sdpa": Phi4MMSdpaAttention,
1403
  }
1404
 
@@ -1511,6 +1512,7 @@ class Phi4MMPreTrainedModel(PreTrainedModel):
1511
  supports_gradient_checkpointing = True
1512
  _no_split_modules = ["Phi4MMDecoderLayer"]
1513
  _skip_keys_device_placement = "past_key_values"
 
1514
  _supports_flash_attn_2 = True
1515
  _supports_sdpa = True
1516
  _supports_cache_class = True
@@ -1807,7 +1809,7 @@ class Phi4MMModel(Phi4MMPreTrainedModel):
1807
  # to infer the attention mask.
1808
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1809
  using_static_cache = isinstance(past_key_values, StaticCache)
1810
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
1811
 
1812
  # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1813
  if (
@@ -1913,7 +1915,7 @@ class Phi4MMModel(Phi4MMPreTrainedModel):
1913
  if config.sliding_window is not None:
1914
  # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
1915
  # the check is needed to verify is current checkpoint was trained with sliding window or not
1916
- if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
1917
  sliding_attend_mask = torch.arange(target_length, device=device) <= (
1918
  cache_position.reshape(-1, 1) - config.sliding_window
1919
  )
@@ -1934,7 +1936,7 @@ class Phi4MMModel(Phi4MMPreTrainedModel):
1934
 
1935
 
1936
  class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
1937
- _tied_weights_keys = ["lm_head.weight"]
1938
 
1939
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi
1940
  def __init__(self, config):
@@ -1949,6 +1951,12 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
1949
  # LoRA related settings
1950
  assert getattr(config, "vision_lora", None) is not None
1951
  from peft import LoraConfig, get_peft_model
 
 
 
 
 
 
1952
  vision_lora_config = LoraConfig(
1953
  r=config.vision_lora['r'],
1954
  lora_alpha=config.vision_lora['lora_alpha'],
@@ -2134,7 +2142,10 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
2134
 
2135
  hidden_states = outputs[0]
2136
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
2137
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
 
 
 
2138
 
2139
  loss = None
2140
  if labels is not None:
 
26
  from torch.nn import CrossEntropyLoss
27
 
28
  from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
  from transformers.generation import GenerationMixin
31
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
32
  from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
41
  add_code_sample_docstrings,
42
  add_start_docstrings,
43
  add_start_docstrings_to_model_forward,
44
+ is_flash_attn_greater_or_equal,
45
  logging,
46
  replace_return_docstrings,
47
  )
 
1134
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
1135
  "with a layer index."
1136
  )
1137
+ kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
1138
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1139
 
1140
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
1190
  # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
1191
  # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
1192
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
1193
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
1194
 
1195
  def forward(
1196
  self,
 
1229
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
1230
  "with a layer index."
1231
  )
1232
+ kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
1233
 
1234
  # Because the input can be padded, the absolute sequence length depends on the max position id.
1235
  rotary_seq_len = (
 
1351
 
1352
  kv_seq_len = key_states.shape[-2]
1353
  if past_key_value is not None:
1354
+ kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
1355
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1356
 
1357
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
1399
  PHI4MM_ATTENTION_CLASSES = {
1400
  "eager": Phi4MMAttention,
1401
  "flash_attention_2": Phi4MMFlashAttention2,
1402
+ "kernels-community/flash-attn2": Phi4MMFlashAttention2,
1403
  "sdpa": Phi4MMSdpaAttention,
1404
  }
1405
 
 
1512
  supports_gradient_checkpointing = True
1513
  _no_split_modules = ["Phi4MMDecoderLayer"]
1514
  _skip_keys_device_placement = "past_key_values"
1515
+ _supports_flash_attn = True
1516
  _supports_flash_attn_2 = True
1517
  _supports_sdpa = True
1518
  _supports_cache_class = True
 
1809
  # to infer the attention mask.
1810
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1811
  using_static_cache = isinstance(past_key_values, StaticCache)
1812
+ using_sliding_window_cache = False # SlidingWindowCache removed in newer transformers
1813
 
1814
  # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1815
  if (
 
1915
  if config.sliding_window is not None:
1916
  # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
1917
  # the check is needed to verify is current checkpoint was trained with sliding window or not
1918
+ if sequence_length > target_length: # SlidingWindowCache removed
1919
  sliding_attend_mask = torch.arange(target_length, device=device) <= (
1920
  cache_position.reshape(-1, 1) - config.sliding_window
1921
  )
 
1936
 
1937
 
1938
  class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
1939
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
1940
 
1941
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi
1942
  def __init__(self, config):
 
1951
  # LoRA related settings
1952
  assert getattr(config, "vision_lora", None) is not None
1953
  from peft import LoraConfig, get_peft_model
1954
+
1955
+ # Add a placeholder prepare_inputs_for_generation to satisfy PEFT's requirements
1956
+ # The actual method is defined on Phi4MMForCausalLM
1957
+ if not hasattr(self.model, 'prepare_inputs_for_generation'):
1958
+ self.model.prepare_inputs_for_generation = lambda *args, **kwargs: None
1959
+
1960
  vision_lora_config = LoraConfig(
1961
  r=config.vision_lora['r'],
1962
  lora_alpha=config.vision_lora['lora_alpha'],
 
2142
 
2143
  hidden_states = outputs[0]
2144
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
2145
+ if num_logits_to_keep is None or num_logits_to_keep == 0:
2146
+ logits = self.lm_head(hidden_states)
2147
+ else:
2148
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
2149
 
2150
  loss = None
2151
  if labels is not None:
speech_conformer_encoder.py CHANGED
@@ -1423,7 +1423,8 @@ class NemoConvSubsampling(torch.nn.Module):
1423
  raise ValueError(f"Not valid sub-sampling: {subsampling}!")
1424
 
1425
  if subsampling in ["dw_striding", "striding"]:
1426
- in_length = torch.tensor(feat_in, dtype=torch.float)
 
1427
  out_length = calc_length(
1428
  lengths=in_length,
1429
  all_paddings=self._left_padding + self._right_padding,
@@ -1432,7 +1433,7 @@ class NemoConvSubsampling(torch.nn.Module):
1432
  ceil_mode=self._ceil_mode,
1433
  repeat_num=self._sampling_num,
1434
  )
1435
- self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
1436
  self.conv2d_subsampling = True
1437
  elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
1438
  self.out = None
 
1423
  raise ValueError(f"Not valid sub-sampling: {subsampling}!")
1424
 
1425
  if subsampling in ["dw_striding", "striding"]:
1426
+ # Force CPU tensor to avoid meta tensor issues with device_map
1427
+ in_length = torch.tensor(feat_in, dtype=torch.float, device='cpu')
1428
  out_length = calc_length(
1429
  lengths=in_length,
1430
  all_paddings=self._left_padding + self._right_padding,
 
1433
  ceil_mode=self._ceil_mode,
1434
  repeat_num=self._sampling_num,
1435
  )
1436
+ self.out = torch.nn.Linear(conv_channels * int(out_length.item()), feat_out)
1437
  self.conv2d_subsampling = True
1438
  elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
1439
  self.out = None