ydqmkkx commited on
Commit
ca68ef2
·
1 Parent(s): b841749
Files changed (1) hide show
  1. models/__init__.py +3 -0
models/__init__.py CHANGED
@@ -59,6 +59,9 @@ class GibbsTTS(nn.Module):
59
 
60
  prompt_wav, sr = torchaudio.load(prompt_audio)
61
  prompt_wav = self.resampler(prompt_wav.to(self.device), sr).unsqueeze(0)
 
 
 
62
  prompt_token = self.codec.encode(prompt_wav)
63
 
64
  ratio = prompt_token.shape[1] / len(prompt_phone)
 
59
 
60
  prompt_wav, sr = torchaudio.load(prompt_audio)
61
  prompt_wav = self.resampler(prompt_wav.to(self.device), sr).unsqueeze(0)
62
+
63
+ if prompt_wav.shape[1] > 1:
64
+ prompt_wav = prompt_wav.mean(dim=1, keepdim=True)
65
  prompt_token = self.codec.encode(prompt_wav)
66
 
67
  ratio = prompt_token.shape[1] / len(prompt_phone)