| """ PyTorch Wav2Vec2-Ebranchformer model.""" |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from transformers.activations import ACT2FN |
| from transformers.models.wav2vec2.modeling_wav2vec2 import ( |
| Wav2Vec2Config, |
| Wav2Vec2ForCTC, |
| Wav2Vec2ForPreTraining, |
| ) |
| from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( |
| Wav2Vec2ConformerConfig, |
| Wav2Vec2ConformerEncoder, |
| ) |
| from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( |
| Wav2Vec2ConformerFeedForward as Wav2Vec2EBranchformerFeedForward, |
| ) |
| from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( |
| Wav2Vec2ConformerModel, |
| ) |
| from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( |
| Wav2Vec2ConformerSelfAttention as Wav2Vec2EBranchformerSelfAttention, |
| ) |
| from transformers.utils import logging |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
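# This module reuses building blocks from HF Transformers: the feed-forward and
# self-attention modules come from Wav2Vec2-Conformer (aliased above), while the
# cgMLP branch and the branch-merging encoder layer below implement the
# E-Branchformer architecture on top of them.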
class Wav2Vec2EBranchformerConfig(Wav2Vec2ConformerConfig, Wav2Vec2Config):
    """Config for the Wav2Vec2-EBranchformer model, extending the Conformer config with
    cgMLP and branch-merging hyperparameters."""

    model_type = "wav2vec2-ebranchformer"

    def __init__(
        self,
        ebranchformer_conv_dropout=0.1,
        csgu_activation="identity",
        csgu_kernel_size=31,
        csgu_use_linear_after_conv=False,
        merge_conv_kernel=31,
        use_macaron_ff=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.csgu_kernel_size = csgu_kernel_size
        self.csgu_activation = csgu_activation
        self.csgu_conv_dropout = ebranchformer_conv_dropout
        self.csgu_use_linear_after_conv = csgu_use_linear_after_conv
        self.merge_conv_kernel = merge_conv_kernel
        self.use_macaron_ff = use_macaron_ff


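# Note on dimensions (illustrative; the concrete values here are assumptions):
# with hidden_size=256 and intermediate_size=1024, channel_proj1 in the cgMLP
# below maps 256 -> 1024, the CSGU gating halves this to 512, and channel_proj2
# maps 512 -> 256.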
class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit (CSGU)."""

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__()

        n_channels = config.intermediate_size // 2
        self.norm = torch.nn.LayerNorm(n_channels)
        self.conv = torch.nn.Conv1d(
            n_channels,
            n_channels,
            kernel_size=config.csgu_kernel_size,
            stride=1,
            padding=(config.csgu_kernel_size - 1) // 2,
            groups=n_channels,
        )
        if config.csgu_use_linear_after_conv:
            self.linear = torch.nn.Linear(n_channels, n_channels)
        else:
            self.linear = None

        if config.csgu_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = ACT2FN[config.csgu_activation]

        self.dropout = torch.nn.Dropout(config.csgu_conv_dropout)

    def forward(self, hidden_states: torch.FloatTensor):
        """Gate one half of the channels with a depthwise convolution of the other half.

        Args:
            hidden_states (torch.Tensor): (N, T, D)

        Returns:
            out (torch.Tensor): (N, T, D/2)
        """
        # Split channels into a residual part x_r and a gating part x_g.
        x_r, x_g = hidden_states.chunk(2, dim=-1)

        x_g = self.norm(x_g)
        # Conv1d expects (N, C, T), hence the transposes around the depthwise conv.
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        x_g = self.act(x_g)
        hidden_states = x_r * x_g
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP)."""

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__()
        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.intermediate_size), torch.nn.GELU()
        )
        self.csgu = ConvolutionalSpatialGatingUnit(config)
        self.channel_proj2 = torch.nn.Linear(config.intermediate_size // 2, config.hidden_size)

    def forward(self, hidden_states: torch.FloatTensor):
        hidden_states = self.channel_proj1(hidden_states)
        hidden_states = self.csgu(hidden_states)
        hidden_states = self.channel_proj2(hidden_states)
        return hidden_states


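# The encoder layer below follows the E-Branchformer design: an optional
# macaron-style feed-forward (half-step residual), two parallel branches
# (global self-attention and a local cgMLP), a depthwise-convolution merge of
# the concatenated branches, an optional second macaron feed-forward, and a
# final layer norm.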
class Wav2Vec2EBranchformerEncoderLayer(nn.Module):
    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__()
        embed_dim = config.hidden_size
        dropout = config.attention_dropout

        # Optional first macaron-style feed-forward.
        if config.use_macaron_ff:
            self.ff1 = nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))
        else:
            self.ff1 = None

        # Global branch: self-attention.
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.self_attn_dropout = torch.nn.Dropout(dropout)
        self.self_attn = Wav2Vec2EBranchformerSelfAttention(config)

        # Local branch: convolutional gating MLP.
        self.cgMLP = ConvolutionalGatingMLP(config)
        self.cgMLP_layer_norm = nn.LayerNorm(config.hidden_size)
        self.cgMLP_dropout = torch.nn.Dropout(dropout)

        # Branch merging: depthwise convolution over the concatenated branches,
        # followed by a projection back to the model dimension.
        self.final_dropout = torch.nn.Dropout(dropout)
        self.merge_proj = torch.nn.Linear(embed_dim + embed_dim, embed_dim)
        self.depthwise_conv_fusion = torch.nn.Conv1d(
            embed_dim + embed_dim,
            embed_dim + embed_dim,
            kernel_size=config.merge_conv_kernel,
            stride=1,
            padding=(config.merge_conv_kernel - 1) // 2,
            groups=embed_dim + embed_dim,
            bias=True,
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim)

        # Optional second macaron-style feed-forward.
        if config.use_macaron_ff:
            self.ff2 = nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))
        else:
            self.ff2 = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        # 1. Optional first macaron-style feed-forward (half-step residual).
        if self.ff1 is not None:
            residual = hidden_states
            hidden_states = residual + 0.5 * self.ff1(hidden_states)

        # 2. Split the input into the two parallel branches.
        residual = hidden_states
        global_branch = hidden_states
        local_branch = hidden_states

        # 3. Global branch: self-attention.
        global_branch = self.self_attn_layer_norm(global_branch)
        global_branch, attn_weights = self.self_attn(
            hidden_states=global_branch,
            attention_mask=attention_mask,
            relative_position_embeddings=relative_position_embeddings,
            output_attentions=output_attentions,
        )
        global_branch = self.self_attn_dropout(global_branch)

        # 4. Local branch: convolutional gating MLP.
        local_branch = self.cgMLP_layer_norm(local_branch)
        local_branch = self.cgMLP(local_branch)

        # 5. Merge the branches: concatenate, refine with a depthwise convolution
        # (with its own residual), then project back to the model dimension.
        hidden_states = torch.cat([global_branch, local_branch], dim=-1)
        merge_residual = hidden_states
        hidden_states = merge_residual + self.depthwise_conv_fusion(hidden_states.transpose(1, 2)).transpose(1, 2)
        hidden_states = self.final_dropout(self.merge_proj(hidden_states))

        # 6. Residual connection around the whole branch block.
        hidden_states = residual + hidden_states

        # 7. Optional second macaron-style feed-forward (half-step residual).
        if self.ff2 is not None:
            residual = hidden_states
            hidden_states = residual + 0.5 * self.ff2(hidden_states)

        # 8. Final layer norm.
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states, attn_weights


class Wav2Vec2EBranchformerEncoder(Wav2Vec2ConformerEncoder):
    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Wav2Vec2EBranchformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        # The convolutional positional embedding built by the parent Conformer
        # encoder is not used by the E-Branchformer layers, so drop it.
        self.pos_conv_embed = None


class Wav2Vec2EBranchformerModel(Wav2Vec2ConformerModel):
    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.encoder = Wav2Vec2EBranchformerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()


class Wav2Vec2EBranchformerForPreTraining(Wav2Vec2ForPreTraining):
    config_class = Wav2Vec2EBranchformerConfig
    base_model_prefix = "wav2vec2"

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
        self.post_init()


class Wav2Vec2EBranchformerForCTC(Wav2Vec2ForCTC):
    config_class = Wav2Vec2EBranchformerConfig
    base_model_prefix = "wav2vec2"

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
        self.post_init()


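# A minimal usage sketch (illustrative only; the hyperparameter values are
# assumptions chosen for speed, not a released configuration): build a randomly
# initialized CTC model and run a forward pass on one second of 16 kHz audio.
if __name__ == "__main__":
    example_config = Wav2Vec2EBranchformerConfig(
        hidden_size=144,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=576,
        vocab_size=32,
    )
    example_model = Wav2Vec2EBranchformerForCTC(example_config).eval()
    dummy_audio = torch.randn(1, 16000)  # (batch, samples)
    with torch.no_grad():
        logits = example_model(dummy_audio).logits
    print(logits.shape)  # (batch, frames, vocab_size)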