Taykhoom committed on
Commit
adac0f5
·
verified ·
1 Parent(s): a8c6b7e

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. modeling_helix_mrna.py +24 -4
README.md CHANGED
@@ -108,6 +108,7 @@ This repository contains modified versions of Helical code.
108
  Modifications include:
109
  - Removal of reliance on helical package
110
  - Removal of some ease-of-use embedding generation code (to standardize usage) and other checks (see original repository for more details)
 
111
 
112
  Not all of the original functionality may be preserved. These changes were made to better integrate with the mRNABench framework which focuses on embedding generation for mRNA sequences. Most of the required code was directly copied from the original Helical repository with minimal changes, so please refer to the original repository for full details on the implementation.
113
 
 
108
  Modifications include:
109
  - Removal of reliance on helical package
110
  - Removal of some ease-of-use embedding generation code (to standardize usage) and other checks (see original repository for more details)
111
+ - Standardized return of attention maps and output embeddings to be in line with Hugging Face convention (i.e. None returned for all Mamba blocks attention weights, and input_embeddings returned when output_hidden_states is True)
112
 
113
  Not all of the original functionality may be preserved. These changes were made to better integrate with the mRNABench framework which focuses on embedding generation for mRNA sequences. Most of the required code was directly copied from the original Helical repository with minimal changes, so please refer to the original repository for full details on the implementation.
114
 
modeling_helix_mrna.py CHANGED
@@ -1431,6 +1431,8 @@ class HelixmRNAOutput(ModelOutput):
1431
  avoid providing the old `input_ids`.
1432
 
1433
  Includes both the State space model state matrices after the selective scan, and the Convolutional states
 
 
1434
  hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1435
  Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1436
  one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
@@ -1440,6 +1442,7 @@ class HelixmRNAOutput(ModelOutput):
1440
 
1441
  last_hidden_state: Optional[torch.FloatTensor] = None
1442
  cache_params: Optional[Mamba2Cache] = None
 
1443
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1444
 
1445
 
@@ -1579,6 +1582,13 @@ class HelixmRNAModel(HelixmRNAPreTrainedModel):
1579
 
1580
  all_hidden_states = () if output_hidden_states else None
1581
  all_self_attns = () if output_attentions else None
 
 
 
 
 
 
 
1582
  for helix_block in self.layers:
1583
 
1584
  layer_mask = (
@@ -1610,15 +1620,22 @@ class HelixmRNAModel(HelixmRNAPreTrainedModel):
1610
  if output_hidden_states:
1611
  all_hidden_states += (layer_outputs[0],)
1612
 
 
 
 
 
 
1613
  hidden_states = self.norm_f(layer_outputs[0])
1614
 
1615
  if output_hidden_states:
1616
  all_hidden_states = all_hidden_states + (hidden_states,)
1617
 
1618
- if output_attentions:
1619
- if layer_outputs[1] is not None:
1620
- # append attentions only of attention layers. Mamba layers return `None` as the attention weights
1621
- all_self_attns += (layer_outputs[1],)
 
 
1622
 
1623
  if use_cache:
1624
  cache_params.seqlen_offset += inputs_embeds.shape[1]
@@ -1630,11 +1647,14 @@ class HelixmRNAModel(HelixmRNAPreTrainedModel):
1630
  if v is not None
1631
  )
1632
 
 
1633
  return HelixmRNAOutput(
1634
  last_hidden_state=hidden_states,
1635
  cache_params=cache_params if use_cache else None,
 
1636
  hidden_states=all_hidden_states,
1637
  )
 
1638
 
1639
  def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
1640
  if self.config._attn_implementation == "flash_attention_2":
 
1431
  avoid providing the old `input_ids`.
1432
 
1433
  Includes both the State space model state matrices after the selective scan, and the Convolutional states
1434
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
1435
+ Attention weights of all attention layers. Each entry is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
1436
  hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1437
  Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1438
  one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
 
1442
 
1443
  last_hidden_state: Optional[torch.FloatTensor] = None
1444
  cache_params: Optional[Mamba2Cache] = None
1445
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
1446
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1447
 
1448
 
 
1582
 
1583
  all_hidden_states = () if output_hidden_states else None
1584
  all_self_attns = () if output_attentions else None
1585
+
1586
+ ####### CHANGE TO BE IN LINE WITH HF CONVENTION #######
1587
+ if output_hidden_states:
1588
+ all_hidden_states += (inputs_embeds,) # index 0 = embedding, matching HF convention
1589
+
1590
+ ####### END OF CHANGE #######
1591
+
1592
  for helix_block in self.layers:
1593
 
1594
  layer_mask = (
 
1620
  if output_hidden_states:
1621
  all_hidden_states += (layer_outputs[0],)
1622
 
1623
+ ####### CHANGE TO BE IN LINE WITH HF CONVENTION #######
1624
+ if output_attentions:
1625
+ all_self_attns += (layer_outputs[1] if len(layer_outputs) > 1 else None,)
1626
+ ####### END OF CHANGE #######
1627
+
1628
  hidden_states = self.norm_f(layer_outputs[0])
1629
 
1630
  if output_hidden_states:
1631
  all_hidden_states = all_hidden_states + (hidden_states,)
1632
 
1633
+ ####### CHANGE TO BE IN LINE WITH HF CONVENTION #######
1634
+ # if output_attentions:
1635
+ # if layer_outputs[1] is not None:
1636
+ # # append attentions only of attention layers. Mamba layers return `None` as the attention weights
1637
+ # all_self_attns += (layer_outputs[1],)
1638
+ ####### END OF CHANGE #######
1639
 
1640
  if use_cache:
1641
  cache_params.seqlen_offset += inputs_embeds.shape[1]
 
1647
  if v is not None
1648
  )
1649
 
1650
+ ####### CHANGE TO BE IN LINE WITH HF CONVENTION #######
1651
  return HelixmRNAOutput(
1652
  last_hidden_state=hidden_states,
1653
  cache_params=cache_params if use_cache else None,
1654
+ attentions=all_self_attns,
1655
  hidden_states=all_hidden_states,
1656
  )
1657
+ ####### END OF CHANGE #######
1658
 
1659
  def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
1660
  if self.config._attn_implementation == "flash_attention_2":