shethjenil committed on
Commit
da9e96d
·
verified ·
1 Parent(s): a705b55

Update modeling_conformer.py

Browse files
Files changed (1) hide show
  1. modeling_conformer.py +7 -10
modeling_conformer.py CHANGED
@@ -61,10 +61,9 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
61
  self.mask_layer.cache_pad_mask = (torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) >= self.cache_length.unsqueeze(1))
62
  return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)
63
 
64
def calc_length(self, lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
    """Return the sequence length after `repeat_num` strided convolutions.

    `all_paddings` is the TOTAL padding added across both ends of the
    sequence (not per side).  Each pass applies the standard conv
    output-length formula: (L + all_paddings - kernel_size) // stride + 1.
    `lengths` may be a plain int or a tensor of lengths; `//` works on both.
    """
    # Loop-invariant shrink term of the conv formula.
    delta = all_paddings - kernel_size
    step = 0
    while step < repeat_num:
        lengths = (lengths + delta) // stride + 1
        step += 1
    return lengths
69
 
70
  def preprocessing(self, x):
@@ -118,13 +117,11 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
118
  cx = torch.where(mask, cx, cx_prev)
119
  last = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), last)
120
 
121
- pos = lengths.unsqueeze(1).clamp(max=max_len - 1)
122
- fill_t = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), torch.full_like(n.unsqueeze(1), pad))
123
- fill_s = torch.where(emitted.unsqueeze(1), t_sec, torch.full_like(t_sec, -1.0))
124
-
125
- tokens = tokens.scatter(1, pos, fill_t)
126
- starts = starts.scatter(1, pos, fill_s)
127
- lengths = lengths + emitted.long()
128
 
129
  return tokens, starts, lengths
130
 
 
61
  self.mask_layer.cache_pad_mask = (torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) >= self.cache_length.unsqueeze(1))
62
  return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)
63
 
64
def calc_length(self, lengths, padding=1, kernel_size=3, stride=2, repeat_num=1):
    """Return the sequence length after `repeat_num` strided convolutions.

    `padding` is the symmetric per-side padding, so each pass applies the
    standard conv output-length formula:
    (L + 2*padding - kernel_size) // stride + 1.
    `lengths` may be a plain int or a tensor of lengths; `//` works on both.
    """
    # The shrink term does not change across iterations, so compute it once.
    shrink = 2 * padding - kernel_size
    for _ in range(repeat_num):
        lengths = (lengths + shrink) // stride + 1
    return lengths
68
 
69
  def preprocessing(self, x):
 
117
  cx = torch.where(mask, cx, cx_prev)
118
  last = torch.where(emitted.unsqueeze(1), n.unsqueeze(1), last)
119
 
120
+ if emitted.any():
121
+ idx = lengths[emitted].unsqueeze(1).clamp(max=max_len - 1)
122
+ tokens[emitted] = tokens[emitted].scatter(1, idx, n[emitted].unsqueeze(1))
123
+ starts[emitted] = starts[emitted].scatter(1, idx, t_sec[emitted])
124
+ lengths[emitted] += 1
 
 
125
 
126
  return tokens, starts, lengths
127