From 6573bc4c5e7b1798c8747c499f65210be9e7993d Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 11:33:00 +0530 Subject: [PATCH] ensure num_channels --- enhancer/models/dccrn.py | 8 ++++++++ enhancer/models/demucs.py | 6 +++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py index 72d8a23..7b1e5b1 100644 --- a/enhancer/models/dccrn.py +++ b/enhancer/models/dccrn.py @@ -261,6 +261,14 @@ class DCCRN(Model): def forward(self, waveform): + if waveform.dim() == 2: + waveform = waveform.unsqueeze(1) + + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" + ) + waveform_stft = self.stft(waveform) real = waveform_stft[:, : self.stft.nfft // 2 + 1] imag = waveform_stft[:, self.stft.nfft // 2 + 1 :] diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index e5fa945..fafb84e 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -204,9 +204,9 @@ class Demucs(Model): if waveform.dim() == 2: waveform = waveform.unsqueeze(1) - if waveform.size(1) != 1: - raise TypeError( - f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" ) if self.normalize: waveform = waveform.mean(dim=1, keepdim=True)