diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml index c33d29a..3f7252c 100644 --- a/enhancer/cli/train_config/dataset/Vctk.yaml +++ b/enhancer/cli/train_config/dataset/Vctk.yaml @@ -2,10 +2,10 @@ _target_: enhancer.data.dataset.EnhancerDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 -stride : 2 +stride : 0.5 sampling_rate: 16000 batch_size: 32 -valid_minutes : 15 +valid_minutes : 25 files: train_clean : clean_trainset_28spk_wav test_clean : clean_testset_wav diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml index 1782ea9..63eeece 100644 --- a/enhancer/cli/train_config/hyperparameters/default.yaml +++ b/enhancer/cli/train_config/hyperparameters/default.yaml @@ -1,7 +1,8 @@ loss : mae metric : [stoi,pesq,si-sdr] lr : 0.0003 -ReduceLr_patience : 5 -ReduceLr_factor : 0.2 +ReduceLr_patience : 10 +Early_stop : True +ReduceLr_factor : 0.5 min_lr : 0.000001 -EarlyStopping_factor : 10 +EarlyStopping_patience : 10 diff --git a/enhancer/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml index d597333..8d76f15 100644 --- a/enhancer/cli/train_config/mlflow/experiment.yaml +++ b/enhancer/cli/train_config/mlflow/experiment.yaml @@ -1,2 +1,2 @@ experiment_name : shahules/enhancer -run_name : Demucs + Vtck with stride + augmentations +run_name : Demucs(ablation_study) diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml index 8bdf60f..988c6e1 100644 --- a/enhancer/cli/train_config/trainer/default.yaml +++ b/enhancer/cli/train_config/trainer/default.yaml @@ -23,7 +23,7 @@ limit_test_batches: 1.0 limit_train_batches: 1.0 limit_val_batches: 1.0 log_every_n_steps: 50 -max_epochs: 200 +max_epochs: 250 max_steps: -1 max_time: null min_epochs: 1 diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index dac2c50..b5bbadf 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -136,6 +136,7 @@ class TaskDataset(pl.LightningDataModule): speaker_index = rng.choice(possible_indices) possible_indices.remove(speaker_index) speaker_name = all_speakers[speaker_index] + print(f"Selected f{speaker_name} for valid") file_indices = [ i for i, file in enumerate(data)