diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml index 6b5d98e..61551bd 100644 --- a/cli/train_config/config.yaml +++ b/cli/train_config/config.yaml @@ -1,5 +1,5 @@ defaults: - - model : Demucs + - model : WaveUnet - dataset : Vctk - optimizer : Adam - hyperparameters : default diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index a6c0d34..b354f55 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -70,6 +70,8 @@ class WaveUnet(Model): loss: Union[str, List] = "mse", metric:Union[str,List] = "mse" ): + duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None + sampling_rate = sampling_rate if dataset is None else dataset.sampling_rate super().__init__(num_channels=num_channels, sampling_rate=sampling_rate,lr=lr, dataset=dataset,duration=duration,loss=loss, metric=metric