| from transformers import Wav2Vec2Config, Wav2Vec2Model |
| from transformers.modeling_outputs import BaseModelOutput |
| import torch |
| import torch.nn.functional as F |
|
|
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        lengths: 1-D integer tensor of shape (batch,) giving each sequence's
            valid length. Cast to long internally.
        max_len: width of the mask; defaults to ``lengths.max()``.

    Returns:
        Bool tensor of shape (batch, max_len) where True marks valid
        (non-padding) positions.
    """
    lengths = lengths.to(torch.long)
    if max_len is None:
        max_len = int(lengths.max().item())

    # Allocate the position ids directly on the target device instead of
    # building on CPU and transferring, and let broadcasting do the work
    # the original's two expand() calls did explicitly.
    ids = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return ids < lengths.unsqueeze(1)
|
|
|
|
def linear_interpolation(features, seq_len):
    """Linearly resample a (batch, time, channels) tensor to ``seq_len`` steps.

    Interpolation runs along the time axis; with ``align_corners=True`` the
    first and last frames of the input are preserved exactly in the output.
    """
    # F.interpolate expects channels-first input, so move time to the last
    # axis, resample, then restore the (batch, time, channels) layout.
    channels_first = features.permute(0, 2, 1)
    resampled = F.interpolate(channels_first, size=seq_len, mode='linear', align_corners=True)
    return resampled.permute(0, 2, 1)
|
|
| |
| |
| |
class Wav2Vec2Model(Wav2Vec2Model):
    """Wav2Vec2 variant whose CNN features are linearly resampled to a
    caller-specified frame count (``seq_len``) before the transformer encoder.

    The original ``forward`` duplicated the bodies of ``feature_extract`` and
    ``encode`` verbatim; it now delegates to them so the pipeline exists in
    exactly one place.
    """

    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)

    def forward(
        self,
        input_values,
        seq_len,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Extract CNN features resampled to ``seq_len`` frames, then encode.

        Equivalent to ``encode(feature_extract(input_values, seq_len), ...)``.
        Returns a ``BaseModelOutput`` (or tuple when ``return_dict`` is falsy).
        """
        extract_features = self.feature_extract(input_values, seq_len)
        return self.encode(
            extract_features,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def feature_extract(
        self,
        input_values,
        seq_len,
    ):
        """Run the CNN feature extractor and resample its output along the
        time axis to exactly ``seq_len`` frames.

        Returns a (batch, seq_len, feature_dim) tensor.
        """
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        return linear_interpolation(extract_features, seq_len=seq_len)

    def encode(
        self,
        extract_features,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Project, optionally mask, and transformer-encode pre-extracted
        (batch, frames, feature_dim) features.

        Args:
            extract_features: output of ``feature_extract``.
            attention_mask: sample-level mask; converted here to a
                feature-frame mask via ``_get_feature_vector_attention_mask``.
            mask_time_indices: SpecAugment-style time-mask indices forwarded
                to ``_mask_hidden_states``.
            output_attentions / output_hidden_states / return_dict: standard
                HF flags; fall back to the model config when None.
        """
        # Default all three output flags from the config, matching the HF
        # convention. The original defaulted only output_hidden_states and
        # return_dict, silently ignoring config.output_attentions.
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None:
            # Downsample the raw-audio mask to one entry per feature frame.
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|