ensure num_channels
This commit is contained in:
parent
77699ce7f9
commit
6573bc4c5e
|
|
@ -261,6 +261,14 @@ class DCCRN(Model):
|
|||
|
||||
def forward(self, waveform):
|
||||
|
||||
if waveform.dim() == 2:
|
||||
waveform = waveform.unsqueeze(1)
|
||||
|
||||
if waveform.size(1) != self.hparams.num_channels:
|
||||
raise ValueError(
|
||||
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
|
||||
)
|
||||
|
||||
waveform_stft = self.stft(waveform)
|
||||
real = waveform_stft[:, : self.stft.nfft // 2 + 1]
|
||||
imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
|
||||
|
|
|
|||
|
|
@ -204,9 +204,9 @@ class Demucs(Model):
|
|||
if waveform.dim() == 2:
|
||||
waveform = waveform.unsqueeze(1)
|
||||
|
||||
if waveform.size(1) != 1:
|
||||
raise TypeError(
|
||||
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
|
||||
if waveform.size(1) != self.hparams.num_channels:
|
||||
raise ValueError(
|
||||
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
|
||||
)
|
||||
if self.normalize:
|
||||
waveform = waveform.mean(dim=1, keepdim=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue