diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py index 2acbb08..f8e4b50 100644 --- a/enhancer/utils/transforms.py +++ b/enhancer/utils/transforms.py @@ -85,7 +85,7 @@ class ConviSTFT(ConvFFT): input = torch.cat([real, imag], 1) out = F.conv_transpose1d(input, self.weight, stride=self.hop_size) coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2 - coeff.to(self.device) + coeff.to(input.device) coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size) out = out / (coeff + 1e-8) pad = self.window_len - self.hop_size