ensure num_channels
This commit is contained in:
parent
77699ce7f9
commit
6573bc4c5e
|
|
@ -261,6 +261,14 @@ class DCCRN(Model):
|
||||||
|
|
||||||
def forward(self, waveform):
|
def forward(self, waveform):
|
||||||
|
|
||||||
|
if waveform.dim() == 2:
|
||||||
|
waveform = waveform.unsqueeze(1)
|
||||||
|
|
||||||
|
if waveform.size(1) != self.hparams.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
|
||||||
|
)
|
||||||
|
|
||||||
waveform_stft = self.stft(waveform)
|
waveform_stft = self.stft(waveform)
|
||||||
real = waveform_stft[:, : self.stft.nfft // 2 + 1]
|
real = waveform_stft[:, : self.stft.nfft // 2 + 1]
|
||||||
imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
|
imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
|
||||||
|
|
|
||||||
|
|
@ -204,9 +204,9 @@ class Demucs(Model):
|
||||||
if waveform.dim() == 2:
|
if waveform.dim() == 2:
|
||||||
waveform = waveform.unsqueeze(1)
|
waveform = waveform.unsqueeze(1)
|
||||||
|
|
||||||
if waveform.size(1) != 1:
|
if waveform.size(1) != self.hparams.num_channels:
|
||||||
raise TypeError(
|
raise ValueError(
|
||||||
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
|
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
|
||||||
)
|
)
|
||||||
if self.normalize:
|
if self.normalize:
|
||||||
waveform = waveform.mean(dim=1, keepdim=True)
|
waveform = waveform.mean(dim=1, keepdim=True)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue