pass metric to Model
This commit is contained in:
parent
71b98ba67c
commit
43261dec16
|
|
@ -108,13 +108,15 @@ class Demucs(Model):
|
||||||
sampling_rate = 16000,
|
sampling_rate = 16000,
|
||||||
lr:float=1e-3,
|
lr:float=1e-3,
|
||||||
dataset:Optional[EnhancerDataset]=None,
|
dataset:Optional[EnhancerDataset]=None,
|
||||||
loss:Union[str, List] = "mse"
|
loss:Union[str, List] = "mse",
|
||||||
|
metric:Union[str, List] = "mse"
|
||||||
|
|
||||||
|
|
||||||
):
|
):
|
||||||
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
||||||
super().__init__(num_channels=num_channels,
|
super().__init__(num_channels=num_channels,
|
||||||
sampling_rate=sampling_rate,lr=lr,
|
sampling_rate=sampling_rate,lr=lr,
|
||||||
dataset=dataset,duration=duration,loss=loss)
|
dataset=dataset,duration=duration,loss=loss, metric=metric)
|
||||||
|
|
||||||
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
|
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
|
||||||
lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
|
lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
|
||||||
|
|
@ -151,16 +153,16 @@ class Demucs(Model):
|
||||||
bidirectional=lstm["bidirectional"]
|
bidirectional=lstm["bidirectional"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self,mixed_signal):
|
def forward(self,waveform):
|
||||||
|
|
||||||
if mixed_signal.dim() == 2:
|
if waveform.dim() == 2:
|
||||||
mixed_signal = mixed_signal.unsqueeze(1)
|
waveform = waveform.unsqueeze(1)
|
||||||
|
|
||||||
if mixed_signal.size(1)!=1:
|
if waveform.size(1)!=1:
|
||||||
raise TypeError(f"Demucs can only process mono channel audio, input has {mixed_signal.size(1)} channels")
|
raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels")
|
||||||
|
|
||||||
length = mixed_signal.shape[-1]
|
length = waveform.shape[-1]
|
||||||
x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
|
x = F.pad(waveform, (0,self.get_padding_length(length) - length))
|
||||||
if self.hparams.resample>1:
|
if self.hparams.resample>1:
|
||||||
x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate,
|
x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate,
|
||||||
target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
|
target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue