Spaces:
Running
on
Zero
Running
on
Zero
Update cosyvoice/hifigan/generator.py
Browse files
cosyvoice/hifigan/generator.py
CHANGED
|
@@ -672,6 +672,7 @@ class CausalHiFTGenerator(HiFTGenerator):
|
|
| 672 |
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
|
| 673 |
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
| 674 |
if finalize is True:
|
|
|
|
| 675 |
x = self.conv_pre(x.cuda())
|
| 676 |
else:
|
| 677 |
x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
|
|
@@ -713,8 +714,7 @@ class CausalHiFTGenerator(HiFTGenerator):
|
|
| 713 |
@torch.inference_mode()
|
| 714 |
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
| 715 |
# mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
|
| 716 |
-
self.f0_predictor.to('cpu')
|
| 717 |
-
f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
|
| 718 |
# f0->source
|
| 719 |
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
| 720 |
s, _, _ = self.m_source(s)
|
|
|
|
| 672 |
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
|
| 673 |
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
| 674 |
if finalize is True:
|
| 675 |
+
print('weight device {}'.format(self.conv_pre.weight.device))
|
| 676 |
x = self.conv_pre(x.cuda())
|
| 677 |
else:
|
| 678 |
x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
|
|
|
|
| 714 |
@torch.inference_mode()
|
| 715 |
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
| 716 |
# mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
|
| 717 |
+
f0 = self.f0_predictor(speech_feat, finalize=finalize).to(speech_feat)
|
|
|
|
| 718 |
# f0->source
|
| 719 |
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
| 720 |
s, _, _ = self.m_source(s)
|