ensure num_channels

This commit is contained in:
shahules786 2022-11-07 11:33:00 +05:30
parent 77699ce7f9
commit 6573bc4c5e
2 changed files with 11 additions and 3 deletions

View File

@ -261,6 +261,14 @@ class DCCRN(Model):
def forward(self, waveform): def forward(self, waveform):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1) != self.hparams.num_channels:
raise ValueError(
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
)
waveform_stft = self.stft(waveform) waveform_stft = self.stft(waveform)
real = waveform_stft[:, : self.stft.nfft // 2 + 1] real = waveform_stft[:, : self.stft.nfft // 2 + 1]
imag = waveform_stft[:, self.stft.nfft // 2 + 1 :] imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]

View File

@ -204,9 +204,9 @@ class Demucs(Model):
if waveform.dim() == 2: if waveform.dim() == 2:
waveform = waveform.unsqueeze(1) waveform = waveform.unsqueeze(1)
if waveform.size(1) != 1: if waveform.size(1) != self.hparams.num_channels:
raise TypeError( raise ValueError(
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
) )
if self.normalize: if self.normalize:
waveform = waveform.mean(dim=1, keepdim=True) waveform = waveform.mean(dim=1, keepdim=True)