diff --git a/enhancer/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml index 8d0ab14..c0b2cf6 100644 --- a/enhancer/cli/train_config/config.yaml +++ b/enhancer/cli/train_config/config.yaml @@ -1,5 +1,5 @@ defaults: - - model : Demucs + - model : WaveUnet - dataset : Vctk - optimizer : Adam - hyperparameters : default diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml index 0acbb36..2f22146 100644 --- a/enhancer/cli/train_config/dataset/Vctk.yaml +++ b/enhancer/cli/train_config/dataset/Vctk.yaml @@ -1,10 +1,9 @@ _target_: enhancer.data.dataset.EnhancerDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 -duration : 4.5 -stride : 0.5 +duration : 2 sampling_rate: 16000 -batch_size: 64 +batch_size: 128 valid_minutes : 15 files: train_clean : clean_trainset_28spk_wav diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml index 4d8b391..0291c8e 100644 --- a/enhancer/cli/train_config/hyperparameters/default.yaml +++ b/enhancer/cli/train_config/hyperparameters/default.yaml @@ -1,6 +1,6 @@ -loss : mae +loss : mse metric : [stoi,pesq,si-sdr] -lr : 0.0003 +lr : 0.001 ReduceLr_patience : 10 ReduceLr_factor : 0.5 min_lr : 0.000001 diff --git a/enhancer/cli/train_config/model/WaveUnet.yaml b/enhancer/cli/train_config/model/WaveUnet.yaml index d641bcd..29d48c7 100644 --- a/enhancer/cli/train_config/model/WaveUnet.yaml +++ b/enhancer/cli/train_config/model/WaveUnet.yaml @@ -1,5 +1,5 @@ _target_: enhancer.models.waveunet.WaveUnet num_channels : 1 -depth : 12 +depth : 9 initial_output_channels: 24 sampling_rate : 16000 diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml index ca866fb..958c418 100644 --- a/enhancer/cli/train_config/trainer/default.yaml +++ b/enhancer/cli/train_config/trainer/default.yaml @@ -34,7 +34,7 @@ num_nodes: 1 num_processes: 1 num_sanity_val_steps: 2 overfit_batches: 0.0 -precision: 16 +precision: 32 profiler: null reload_dataloaders_every_n_epochs: 0 replace_sampler_ddp: True