pass metric to Model

This commit is contained in:
shahules786 2022-09-24 12:47:12 +05:30
parent 71b98ba67c
commit 43261dec16
1 changed files with 11 additions and 9 deletions

View File

@ -108,13 +108,15 @@ class Demucs(Model):
sampling_rate = 16000, sampling_rate = 16000,
lr:float=1e-3, lr:float=1e-3,
dataset:Optional[EnhancerDataset]=None, dataset:Optional[EnhancerDataset]=None,
loss:Union[str, List] = "mse" loss:Union[str, List] = "mse",
metric:Union[str, List] = "mse"
): ):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
super().__init__(num_channels=num_channels, super().__init__(num_channels=num_channels,
sampling_rate=sampling_rate,lr=lr, sampling_rate=sampling_rate,lr=lr,
dataset=dataset,duration=duration,loss=loss) dataset=dataset,duration=duration,loss=loss, metric=metric)
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder) encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS,lstm) lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
@ -151,16 +153,16 @@ class Demucs(Model):
bidirectional=lstm["bidirectional"] bidirectional=lstm["bidirectional"]
) )
def forward(self,mixed_signal): def forward(self,waveform):
if mixed_signal.dim() == 2: if waveform.dim() == 2:
mixed_signal = mixed_signal.unsqueeze(1) waveform = waveform.unsqueeze(1)
if mixed_signal.size(1)!=1: if waveform.size(1)!=1:
raise TypeError(f"Demucs can only process mono channel audio, input has {mixed_signal.size(1)} channels") raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels")
length = mixed_signal.shape[-1] length = waveform.shape[-1]
x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length)) x = F.pad(waveform, (0,self.get_padding_length(length) - length))
if self.hparams.resample>1: if self.hparams.resample>1:
x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate, x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate,
target_sr=int(self.hparams.sampling_rate * self.hparams.resample)) target_sr=int(self.hparams.sampling_rate * self.hparams.resample))