Update modeling_conformer.py
Browse files- modeling_conformer.py +7 -10
modeling_conformer.py
CHANGED
|
@@ -61,10 +61,9 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
|
|
| 61 |
self.mask_layer.cache_pad_mask = (torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) >= self.cache_length.unsqueeze(1))
|
| 62 |
return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)
|
| 63 |
|
| 64 |
-
def calc_length(self, lengths,
|
| 65 |
-
add_pad = all_paddings - kernel_size
|
| 66 |
for _ in range(repeat_num):
|
| 67 |
-
lengths = (lengths +
|
| 68 |
return lengths
|
| 69 |
|
| 70 |
def preprocessing(self, x):
|
|
@@ -118,13 +117,11 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
|
|
| 118 |
cx = torch.where(mask, cx, cx_prev)
|
| 119 |
last = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), last)
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
starts = starts.scatter(1, pos, fill_s)
|
| 127 |
-
lengths = lengths + emitted.long()
|
| 128 |
|
| 129 |
return tokens, starts, lengths
|
| 130 |
|
|
|
|
| 61 |
self.mask_layer.cache_pad_mask = (torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) >= self.cache_length.unsqueeze(1))
|
| 62 |
return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)
|
| 63 |
|
| 64 |
+
def calc_length(self, lengths, padding=1, kernel_size=3, stride=2, repeat_num=1):
    """Compute sequence lengths after repeated strided downsampling stages.

    Each stage applies the usual convolution output-size rule
    ``(L + 2 * padding - kernel_size) // stride + 1``; the rule is applied
    ``repeat_num`` times in a row, feeding each result into the next stage.

    Args:
        lengths: Length(s) before downsampling. Anything supporting ``+``,
            ``-`` and ``//`` with integers works (a plain ``int`` does;
            presumably a tensor of per-sample lengths in practice --
            confirm against callers).
        padding: Padding added on each side of the sequence per stage.
        kernel_size: Kernel width of each stage.
        stride: Stride of each stage.
        repeat_num: Number of stages to apply; ``0`` returns ``lengths``
            unchanged.

    Returns:
        The length(s) after all stages, of the same kind as ``lengths``.
    """
    # Padding is applied on both ends of the sequence.
    total_padding = 2 * padding
    for _stage in range(repeat_num):
        covered = lengths + total_padding - kernel_size
        lengths = covered // stride + 1
    return lengths
|
| 68 |
|
| 69 |
def preprocessing(self, x):
|
|
|
|
| 117 |
cx = torch.where(mask, cx, cx_prev)
|
| 118 |
last = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), last)
|
| 119 |
|
| 120 |
+
if emitted.any():
|
| 121 |
+
idx = lengths[emitted].unsqueeze(1).clamp(max=max_len - 1)
|
| 122 |
+
tokens[emitted] = tokens[emitted].scatter(1, idx, n[emitted].unsqueeze(1))
|
| 123 |
+
starts[emitted] = starts[emitted].scatter(1, idx, t_sec[emitted])
|
| 124 |
+
lengths[emitted] += 1
|
|
|
|
|
|
|
| 125 |
|
| 126 |
return tokens, starts, lengths
|
| 127 |
|