Upload folder using huggingface_hub
Browse files- README.md +1 -0
- modeling_helix_mrna.py +24 -4
README.md
CHANGED
|
@@ -108,6 +108,7 @@ This repository contains modified versions of Helical code.
|
|
| 108 |
Modifications include:
|
| 109 |
- Removal of reliance on helical package
|
| 110 |
- Removal of some ease-of-use embedding generation code (to standardize usage) and other checks (see original repository for more details)
|
|
|
|
| 111 |
|
| 112 |
Not all of the original functionality may be preserved. These changes were made to better integrate with the mRNABench framework which focuses on embedding generation for mRNA sequences. Most of the required code was directly copied from the original Helical repository with minimal changes, so please refer to the original repository for full details on the implementation.
|
| 113 |
|
|
|
|
| 108 |
Modifications include:
|
| 109 |
- Removal of reliance on helical package
|
| 110 |
- Removal of some ease-of-use embedding generation code (to standardize usage) and other checks (see original repository for more details)
|
| 111 |
+
- Standardized return of attention maps and output embeddings to be in line with Hugging Face convention (i.e. `None` returned for all Mamba blocks' attention weights, and input embeddings returned when `output_hidden_states` is True)
|
| 112 |
|
| 113 |
Not all of the original functionality may be preserved. These changes were made to better integrate with the mRNABench framework which focuses on embedding generation for mRNA sequences. Most of the required code was directly copied from the original Helical repository with minimal changes, so please refer to the original repository for full details on the implementation.
|
| 114 |
|
modeling_helix_mrna.py
CHANGED
|
@@ -1431,6 +1431,8 @@ class HelixmRNAOutput(ModelOutput):
|
|
| 1431 |
avoid providing the old `input_ids`.
|
| 1432 |
|
| 1433 |
Includes both the State space model state matrices after the selective scan, and the Convolutional states
|
|
|
|
|
|
|
| 1434 |
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 1435 |
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 1436 |
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
|
@@ -1440,6 +1442,7 @@ class HelixmRNAOutput(ModelOutput):
|
|
| 1440 |
|
| 1441 |
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 1442 |
cache_params: Optional[Mamba2Cache] = None
|
|
|
|
| 1443 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 1444 |
|
| 1445 |
|
|
@@ -1579,6 +1582,13 @@ class HelixmRNAModel(HelixmRNAPreTrainedModel):
|
|
| 1579 |
|
| 1580 |
all_hidden_states = () if output_hidden_states else None
|
| 1581 |
all_self_attns = () if output_attentions else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1582 |
for helix_block in self.layers:
|
| 1583 |
|
| 1584 |
layer_mask = (
|
|
@@ -1610,15 +1620,22 @@ class HelixmRNAModel(HelixmRNAPreTrainedModel):
|
|
| 1610 |
if output_hidden_states:
|
| 1611 |
all_hidden_states += (layer_outputs[0],)
|
| 1612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1613 |
hidden_states = self.norm_f(layer_outputs[0])
|
| 1614 |
|
| 1615 |
if output_hidden_states:
|
| 1616 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1617 |
|
| 1618 |
-
|
| 1619 |
-
|
| 1620 |
-
|
| 1621 |
-
|
|
|
|
|
|
|
| 1622 |
|
| 1623 |
if use_cache:
|
| 1624 |
cache_params.seqlen_offset += inputs_embeds.shape[1]
|
|
@@ -1630,11 +1647,14 @@ class HelixmRNAModel(HelixmRNAPreTrainedModel):
|
|
| 1630 |
if v is not None
|
| 1631 |
)
|
| 1632 |
|
|
|
|
| 1633 |
return HelixmRNAOutput(
|
| 1634 |
last_hidden_state=hidden_states,
|
| 1635 |
cache_params=cache_params if use_cache else None,
|
|
|
|
| 1636 |
hidden_states=all_hidden_states,
|
| 1637 |
)
|
|
|
|
| 1638 |
|
| 1639 |
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
| 1640 |
if self.config._attn_implementation == "flash_attention_2":
|
|
|
|
| 1431 |
avoid providing the old `input_ids`.
|
| 1432 |
|
| 1433 |
Includes both the State space model state matrices after the selective scan, and the Convolutional states
|
| 1434 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 1435 |
+
Attention weights of all attention layers. Each entry is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, or `None` for Mamba blocks, which produce no attention weights.
|
| 1436 |
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 1437 |
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 1438 |
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
|
|
|
| 1442 |
|
| 1443 |
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 1444 |
cache_params: Optional[Mamba2Cache] = None
|
| 1445 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 1446 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 1447 |
|
| 1448 |
|
|
|
|
| 1582 |
|
| 1583 |
all_hidden_states = () if output_hidden_states else None
|
| 1584 |
all_self_attns = () if output_attentions else None
|
| 1585 |
+
|
| 1586 |
+
####### CHANGE TO BE IN LINE WITH HF CONVENTION #######
|
| 1587 |
+
if output_hidden_states:
|
| 1588 |
+
all_hidden_states += (inputs_embeds,) # index 0 = embedding, matching HF convention
|
| 1589 |
+
|
| 1590 |
+
####### END OF CHANGE #######
|
| 1591 |
+
|
| 1592 |
for helix_block in self.layers:
|
| 1593 |
|
| 1594 |
layer_mask = (
|
|
|
|
| 1620 |
if output_hidden_states:
|
| 1621 |
all_hidden_states += (layer_outputs[0],)
|
| 1622 |
|
| 1623 |
+
####### CHANGE TO BE IN LINE WITH HF CONVENTION #######
|
| 1624 |
+
if output_attentions:
|
| 1625 |
+
all_self_attns += (layer_outputs[1] if len(layer_outputs) > 1 else None,)
|
| 1626 |
+
####### END OF CHANGE #######
|
| 1627 |
+
|
| 1628 |
hidden_states = self.norm_f(layer_outputs[0])
|
| 1629 |
|
| 1630 |
if output_hidden_states:
|
| 1631 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1632 |
|
| 1633 |
+
####### CHANGE TO BE IN LINE WITH HF CONVENTION #######
|
| 1634 |
+
# if output_attentions:
|
| 1635 |
+
# if layer_outputs[1] is not None:
|
| 1636 |
+
# # append attentions only of attention layers. Mamba layers return `None` as the attention weights
|
| 1637 |
+
# all_self_attns += (layer_outputs[1],)
|
| 1638 |
+
####### END OF CHANGE #######
|
| 1639 |
|
| 1640 |
if use_cache:
|
| 1641 |
cache_params.seqlen_offset += inputs_embeds.shape[1]
|
|
|
|
| 1647 |
if v is not None
|
| 1648 |
)
|
| 1649 |
|
| 1650 |
+
####### CHANGE TO BE IN LINE WITH HF CONVENTION #######
|
| 1651 |
return HelixmRNAOutput(
|
| 1652 |
last_hidden_state=hidden_states,
|
| 1653 |
cache_params=cache_params if use_cache else None,
|
| 1654 |
+
attentions=all_self_attns,
|
| 1655 |
hidden_states=all_hidden_states,
|
| 1656 |
)
|
| 1657 |
+
####### END OF CHANGE #######
|
| 1658 |
|
| 1659 |
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
| 1660 |
if self.config._attn_implementation == "flash_attention_2":
|