diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py
index 571a915..115f63e 100644
--- a/enhancer/models/demucs.py
+++ b/enhancer/models/demucs.py
@@ -108,13 +108,15 @@ class Demucs(Model):
         sampling_rate = 16000,
         lr:float=1e-3,
         dataset:Optional[EnhancerDataset]=None,
-        loss:Union[str, List] = "mse"
+        loss:Union[str, List] = "mse",
+        metric:Union[str, List] = "mse"
+
     ):
 
         duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
         super().__init__(num_channels=num_channels,
                         sampling_rate=sampling_rate,lr=lr,
-                        dataset=dataset,duration=duration,loss=loss)
+                        dataset=dataset,duration=duration,loss=loss, metric=metric)
 
         encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
         lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
@@ -151,16 +153,16 @@ class Demucs(Model):
             bidirectional=lstm["bidirectional"]
         )
 
-    def forward(self,mixed_signal):
+    def forward(self,waveform):
 
-        if mixed_signal.dim() == 2:
-            mixed_signal = mixed_signal.unsqueeze(1)
+        if waveform.dim() == 2:
+            waveform = waveform.unsqueeze(1)
 
-        if mixed_signal.size(1)!=1:
-            raise TypeError(f"Demucs can only process mono channel audio, input has {mixed_signal.size(1)} channels")
+        if waveform.size(1)!=1:
+            raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels")
 
-        length = mixed_signal.shape[-1]
-        x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
+        length = waveform.shape[-1]
+        x = F.pad(waveform, (0,self.get_padding_length(length) - length))
 
         if self.hparams.resample>1:
             x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate, target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
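
For reviewers, a minimal usage sketch of the new `metric` keyword and the renamed `forward(waveform)` argument. This is an illustration under stated assumptions, not code from the repo: the import path, the ability to instantiate `Demucs` without an `EnhancerDataset`, and the default `encoder_decoder`/`lstm`/`resample` settings are assumed; only the parameters visible in this diff are relied on.

```python
import torch

# Import path assumed from the file location shown in this diff.
from enhancer.models.demucs import Demucs

# Construct the model with the new `metric` keyword; "mse" mirrors the
# default added in the diff. Instantiating without a dataset is assumed
# to be valid because `dataset` is Optional.
model = Demucs(
    sampling_rate=16000,
    lr=1e-3,
    loss="mse",
    metric="mse",  # new argument, forwarded to Model.__init__
)

# forward() now takes `waveform`: a (batch, samples) tensor is unsqueezed
# to (batch, 1, samples), padded to a valid length, and resampled if
# hparams.resample > 1, per the updated forward() above.
waveform = torch.randn(4, 16000)  # 4 mono clips, 1 s at 16 kHz (illustrative)
enhanced = model(waveform)
```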