diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py index 72d8a23..7b1e5b1 100644 --- a/enhancer/models/dccrn.py +++ b/enhancer/models/dccrn.py @@ -261,6 +261,14 @@ class DCCRN(Model): def forward(self, waveform): + if waveform.dim() == 2: + waveform = waveform.unsqueeze(1) + + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" + ) + waveform_stft = self.stft(waveform) real = waveform_stft[:, : self.stft.nfft // 2 + 1] imag = waveform_stft[:, self.stft.nfft // 2 + 1 :] diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index e5fa945..fafb84e 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -204,9 +204,9 @@ class Demucs(Model): if waveform.dim() == 2: waveform = waveform.unsqueeze(1) - if waveform.size(1) != 1: - raise TypeError( - f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" ) if self.normalize: waveform = waveform.mean(dim=1, keepdim=True)