From 40e2d6e0b02553c5f9eff361959a08f068f37ccf Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 23 Oct 2022 12:32:58 +0530
Subject: [PATCH 01/24] change to mapstyle

---
 enhancer/data/dataset.py | 86 ++++++++++++----------------------------
 1 file changed, 26 insertions(+), 60 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 7f7ae67..08b402f 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -1,14 +1,12 @@
 import math
 import multiprocessing
 import os
-import random
-from itertools import chain, cycle
 from typing import Optional
 
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset, IterableDataset
+from torch.utils.data import DataLoader, Dataset
 
 from enhancer.data.fileprocessor import Fileprocessor
 from enhancer.utils import check_files
@@ -16,13 +14,15 @@ from enhancer.utils.config import Files
 from enhancer.utils.io import Audio
 from enhancer.utils.random import create_unique_rng
 
+LARGE_NUM = 2147483647
 
-class TrainDataset(IterableDataset):
+
+class TrainDataset(Dataset):
     def __init__(self, dataset):
         self.dataset = dataset
 
-    def __iter__(self):
-        return self.dataset.train__iter__()
+    def __getitem__(self, idx):
+        return self.dataset.train__getitem__(idx)
 
     def __len__(self):
         return self.dataset.train__len__()
@@ -135,16 +135,11 @@ class TaskDataset(pl.LightningDataModule):
     def prepare_traindata(self, data):
         train_data = []
         for item in data:
-            samples_metadata = []
             clean, noisy, total_dur = item.values()
             num_segments = self.get_num_segments(
                 total_dur, self.duration, self.stride
             )
-            for index in range(num_segments):
-                start = index * self.stride
-                samples_metadata.append(
-                    ({"clean": clean, "noisy": noisy}, start)
-                )
+            samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
             train_data.append(samples_metadata)
         return train_data
 
@@ -175,31 +170,20 @@ class TaskDataset(pl.LightningDataModule):
         return metadata
 
     def train_collatefn(self, batch):
-        output = {"noisy": [], "clean": []}
-        for item in batch:
-            output["noisy"].append(item["noisy"])
-            output["clean"].append(item["clean"])
+        raise NotImplementedError("Not implemented")
 
-        output["clean"] = torch.stack(output["clean"], dim=0)
-        output["noisy"] = torch.stack(output["noisy"], dim=0)
-        return output
-
-    def worker_init_fn(self, _):
-        worker_info = torch.utils.data.get_worker_info()
-        dataset = worker_info.dataset
-        worker_id = worker_info.id
-        split_size = len(dataset.dataset.train_data) // worker_info.num_workers
-        dataset.data = dataset.dataset.train_data[
-            worker_id * split_size : (worker_id + 1) * split_size
-        ]
+    @property
+    def generator(self):
+        generator = torch.Generator()
+        generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM)
+        return generator
 
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=None,
+            batch_size=self.batch_size,
             num_workers=self.num_workers,
-            collate_fn=self.train_collatefn,
-            worker_init_fn=self.worker_init_fn,
+            generator=self.generator,
         )
 
     def val_dataloader(self):
@@ -280,35 +264,16 @@ class EnhancerDataset(TaskDataset):
 
         super().setup(stage=stage)
 
-    def random_sample(self, train_data):
-        return random.sample(train_data, len(train_data))
+    def train__getitem__(self, idx):
 
-    def train__iter__(self):
-        rng = create_unique_rng(self.model.current_epoch)
-        train_data = rng.sample(self.train_data, len(self.train_data))
-        return zip(
-            *[
-                self.get_stream(self.random_sample(train_data))
-                for i in range(self.batch_size)
-            ]
-        )
-
-    def get_stream(self, data):
-        return chain.from_iterable(map(self.process_data, cycle(data)))
-
-    def process_data(self, data):
-        for item in data:
-            yield self.prepare_segment(*item)
-
-    @staticmethod
-    def get_num_segments(file_duration, duration, stride):
-
-        if file_duration < duration:
-            num_segments = 1
-        else:
-            num_segments = math.ceil((file_duration - duration) / stride) + 1
-
-        return num_segments
+        for filedict, num_samples in self.train_data:
+            if idx >= num_samples:
+                idx -= num_samples
+                continue
+            start = 0
+            if self.duration is not None:
+                start = idx * self.stride
+            return self.prepare_segment(filedict, start)
 
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
@@ -348,7 +313,8 @@ class EnhancerDataset(TaskDataset):
         }
 
     def train__len__(self):
-        return sum([len(item) for item in self.train_data]) // (self.batch_size)
+        _, num_examples = list(zip(*self.train_data))
+        return sum(num_examples)
 
     def val__len__(self):
         return len(self._validation)
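Patch 01 above replaces the streaming `IterableDataset` pipeline with a map-style `Dataset`, so every training sample gets a stable integer index instead of being pulled from per-worker cyclic streams. A minimal sketch of the indexing scheme it introduces (the file list and `locate` helper below are illustrative, not part of the patch):

```python
# Each file contributes `num_segments` consecutive global indices; __getitem__
# walks the (file, count) pairs to find which file a flat index falls into.
train_data = [
    ({"clean": "p225_001_clean.wav", "noisy": "p225_001_noisy.wav"}, 3),
    ({"clean": "p226_002_clean.wav", "noisy": "p226_002_noisy.wav"}, 2),
]

def locate(idx, stride=0.5):
    """Map a flat sample index to (file_dict, segment_start_in_seconds)."""
    for filedict, num_segments in train_data:
        if idx >= num_segments:
            idx -= num_segments  # skip past this file's block of indices
            continue
        return filedict, idx * stride

print(locate(0))  # first segment of the first file, start = 0.0 s
print(locate(4))  # second segment of the second file, start = 0.5 s
```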
From ea5c78798add582922f36fd6bbc38194776b13f8 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 23 Oct 2022 12:33:38 +0530
Subject: [PATCH 02/24] model assignment

---
 enhancer/models/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index dc5219d..3b60b85 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -113,6 +113,8 @@ class Model(pl.LightningModule):
         if stage == "fit":
             torch.cuda.empty_cache()
             self.dataset.setup(stage)
+            self.dataset.model = self
+
             print(
                 "Total train duration",
                 self.dataset.train_dataloader().dataset.__len__()
@@ -134,7 +136,6 @@ class Model(pl.LightningModule):
                 / 60,
                 "minutes",
             )
-            self.dataset.model = self
 
     def train_dataloader(self):
         return self.dataset.train_dataloader()

From fc41de1530b926785624d40dbc812042b4a7fd75 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 23 Oct 2022 12:36:43 +0530
Subject: [PATCH 03/24] VCTK + DEMUCS

---
 enhancer/cli/train_config/dataset/Vctk.yaml    | 4 ++--
 enhancer/cli/train_config/trainer/default.yaml | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index 0e1f38f..a5a5b5f 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -4,8 +4,8 @@ root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 4.5
 stride : 0.5
 sampling_rate: 16000
-batch_size: 4
-valid_minutes : 1
+batch_size: 128
+valid_minutes : 15
 files:
   train_clean : clean_trainset_28spk_wav
   test_clean : clean_testset_wav
diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml
index 01914e4..958c418 100644
--- a/enhancer/cli/train_config/trainer/default.yaml
+++ b/enhancer/cli/train_config/trainer/default.yaml
@@ -1,5 +1,5 @@
 _target_: pytorch_lightning.Trainer
-accelerator: auto
+accelerator: gpu
 accumulate_grad_batches: 1
 amp_backend: native
 auto_lr_find: True
@@ -9,7 +9,7 @@ benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: 1
+devices: 2
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True
@@ -23,7 +23,7 @@ limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
 log_every_n_steps: 50
-max_epochs: 3
+max_epochs: 200
 max_steps: -1
 max_time: null
 min_epochs: 1
@@ -38,7 +38,7 @@ precision: 32
 profiler: null
 reload_dataloaders_every_n_epochs: 0
 replace_sampler_ddp: True
-strategy: null
+strategy: ddp
 sync_batchnorm: False
 tpu_cores: null
 track_grad_norm: -1
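Patch 02 moves the `self.dataset.model = self` assignment ahead of the first `train_dataloader()` call; without it, the `generator` property added in patch 01 would hit a missing `self.model` reference while Lightning builds the loaders. A hedged sketch of what that epoch-seeded generator buys (standalone code, not the repo's API):

```python
# Seeding a fresh torch.Generator from current_epoch + LARGE_NUM gives each
# epoch a different but fully reproducible shuffle order.
import torch

LARGE_NUM = 2147483647  # same constant patch 01 introduces

def epoch_generator(current_epoch: int) -> torch.Generator:
    return torch.Generator().manual_seed(current_epoch + LARGE_NUM)

perm_a = torch.randperm(10, generator=epoch_generator(0))
perm_b = torch.randperm(10, generator=epoch_generator(0))
perm_c = torch.randperm(10, generator=epoch_generator(1))
assert torch.equal(perm_a, perm_b)      # same epoch -> identical order
assert not torch.equal(perm_a, perm_c)  # next epoch -> (almost surely) new order
```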
From 3128fed71e68e13c7624fe661778cfe45b5fbe4a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 23 Oct 2022 12:38:20 +0530
Subject: [PATCH 04/24] params

---
 enhancer/cli/train_config/hyperparameters/default.yaml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml
index 7e4cda3..4d8b391 100644
--- a/enhancer/cli/train_config/hyperparameters/default.yaml
+++ b/enhancer/cli/train_config/hyperparameters/default.yaml
@@ -1,7 +1,7 @@
-loss : mse
-metric : mae
-lr : 0.0001
-ReduceLr_patience : 5
-ReduceLr_factor : 0.1
+loss : mae
+metric : [stoi,pesq,si-sdr]
+lr : 0.0003
+ReduceLr_patience : 10
+ReduceLr_factor : 0.5
 min_lr : 0.000001
 EarlyStopping_factor : 10

From 460366bd8b4e5e3f0e90176d542635680edc1281 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 23 Oct 2022 17:15:17 +0530
Subject: [PATCH 05/24] min conf acc ablation study

---
 enhancer/cli/train_config/model/Demucs.yaml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/enhancer/cli/train_config/model/Demucs.yaml b/enhancer/cli/train_config/model/Demucs.yaml
index 3c565ee..513e603 100644
--- a/enhancer/cli/train_config/model/Demucs.yaml
+++ b/enhancer/cli/train_config/model/Demucs.yaml
@@ -1,11 +1,11 @@
 _target_: enhancer.models.demucs.Demucs
 num_channels: 1
-resample: 2
+resample: 4
 sampling_rate : 16000
 
 encoder_decoder:
-  depth: 5
-  initial_output_channels: 32
+  depth: 4
+  initial_output_channels: 64
   kernel_size: 8
   stride: 4
   growth_factor: 2

From 97b4a61d9c96bf7b2568312a6f55e5cf5ffff14c Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 23 Oct 2022 19:07:53 +0530
Subject: [PATCH 06/24] half BS

---
 enhancer/cli/train_config/dataset/Vctk.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index a5a5b5f..0acbb36 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -4,7 +4,7 @@ root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 4.5
 stride : 0.5
 sampling_rate: 16000
-batch_size: 128
+batch_size: 64
 valid_minutes : 15
 files:
   train_clean : clean_trainset_28spk_wav

From 101ee563cb553db16c145a1d36ea81d1f76b63c5 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 23 Oct 2022 19:30:46 +0530
Subject: [PATCH 07/24] decrease precision

---
 enhancer/cli/train_config/trainer/default.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml
index 958c418..ca866fb 100644
--- a/enhancer/cli/train_config/trainer/default.yaml
+++ b/enhancer/cli/train_config/trainer/default.yaml
@@ -34,7 +34,7 @@ num_nodes: 1
 num_processes: 1
 num_sanity_val_steps: 2
 overfit_batches: 0.0
-precision: 32
+precision: 16
 profiler: null
 reload_dataloaders_every_n_epochs: 0
 replace_sampler_ddp: True
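Patches 04-07 are pure configuration sweeps over these Hydra YAML groups (loss and learning-rate settings, Demucs depth and width, batch size, numeric precision). For orientation: `_target_` names the class to build and every other key becomes a constructor argument. A rough sketch of how a Hydra CLI consumes these files (simplified and assumed from the paths above, not the repo's exact entry point):

```python
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig

@hydra.main(config_path="train_config", config_name="config", version_base=None)
def main(config: DictConfig) -> None:
    # config.dataset / config.trainer are the composed YAML nodes shown above;
    # instantiate() imports each _target_ and calls it with the remaining keys.
    dataset = instantiate(config.dataset)  # -> EnhancerDataset(root_dir=..., ...)
    trainer = instantiate(config.trainer)  # -> pytorch_lightning.Trainer(...)
    print(type(dataset).__name__, type(trainer).__name__)

if __name__ == "__main__":
    main()
```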
From 75ebef24621560d4a6b8b289c5b886c373326342 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 24 Oct 2022 10:01:54 +0530
Subject: [PATCH 08/24] Waveunet w/o stride

---
 enhancer/cli/train_config/config.yaml                  | 2 +-
 enhancer/cli/train_config/dataset/Vctk.yaml            | 5 ++---
 enhancer/cli/train_config/hyperparameters/default.yaml | 4 ++--
 enhancer/cli/train_config/model/WaveUnet.yaml          | 2 +-
 enhancer/cli/train_config/trainer/default.yaml         | 2 +-
 5 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/enhancer/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml
index 8d0ab14..c0b2cf6 100644
--- a/enhancer/cli/train_config/config.yaml
+++ b/enhancer/cli/train_config/config.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - model : Demucs
+  - model : WaveUnet
   - dataset : Vctk
  - optimizer : Adam
   - hyperparameters : default
diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index 0acbb36..2f22146 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -1,10 +1,9 @@
 _target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
-duration : 4.5
-stride : 0.5
+duration : 2
 sampling_rate: 16000
-batch_size: 64
+batch_size: 128
 valid_minutes : 15
 files:
   train_clean : clean_trainset_28spk_wav
diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml
index 4d8b391..0291c8e 100644
--- a/enhancer/cli/train_config/hyperparameters/default.yaml
+++ b/enhancer/cli/train_config/hyperparameters/default.yaml
@@ -1,6 +1,6 @@
-loss : mae
+loss : mse
 metric : [stoi,pesq,si-sdr]
-lr : 0.0003
+lr : 0.001
 ReduceLr_patience : 10
 ReduceLr_factor : 0.5
 min_lr : 0.000001
diff --git a/enhancer/cli/train_config/model/WaveUnet.yaml b/enhancer/cli/train_config/model/WaveUnet.yaml
index d641bcd..29d48c7 100644
--- a/enhancer/cli/train_config/model/WaveUnet.yaml
+++ b/enhancer/cli/train_config/model/WaveUnet.yaml
@@ -1,5 +1,5 @@
 _target_: enhancer.models.waveunet.WaveUnet
 num_channels : 1
-depth : 12
+depth : 9
 initial_output_channels: 24
 sampling_rate : 16000
diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml
index ca866fb..958c418 100644
--- a/enhancer/cli/train_config/trainer/default.yaml
+++ b/enhancer/cli/train_config/trainer/default.yaml
@@ -34,7 +34,7 @@ num_nodes: 1
 num_processes: 1
 num_sanity_val_steps: 2
 overfit_batches: 0.0
-precision: 16
+precision: 32
 profiler: null
 reload_dataloaders_every_n_epochs: 0
 replace_sampler_ddp: True

From 5dc5fd8f901e916aabcaa16ba173a353a5cd582e Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 24 Oct 2022 21:15:25 +0530
Subject: [PATCH 09/24] default stride None

---
 enhancer/data/dataset.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 08b402f..ab2f1ce 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -14,6 +14,9 @@ from enhancer.utils.config import Files
 from enhancer.utils.io import Audio
 from enhancer.utils.random import create_unique_rng
 
+# from torch_audiomentations import Compose
+
+
 LARGE_NUM = 2147483647
 
 
@@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
+        # augmentations: Optional[Compose] = None,
     ):
         super().__init__()
 
@@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
         else:
             raise ValueError("valid_minutes must be greater than 0")
 
+        # self.augmentations = augmentations
+
     def setup(self, stage: Optional[str] = None):
         """
         prepare train/validation/test data splits
@@ -161,7 +167,9 @@ class TaskDataset(pl.LightningDataModule):
             if total_dur < self.duration:
                 metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
             else:
-                num_segments = round(total_dur / self.duration)
+                num_segments = self.get_num_segments(
+                    total_dur, self.duration, self.duration
+                )
                 for index in range(num_segments):
                     start_time = index * self.duration
                     metadata.append(
@@ -175,8 +183,11 @@ class TaskDataset(pl.LightningDataModule):
     @property
     def generator(self):
         generator = torch.Generator()
-        generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM)
-        return generator
+        if hasattr(self, "model"):
+            seed = self.model.current_epoch + LARGE_NUM
+        else:
+            seed = LARGE_NUM
+        return generator.manual_seed(seed)
 
     def train_dataloader(self):
         return DataLoader(
@@ -235,11 +246,12 @@ class EnhancerDataset(TaskDataset):
         files: Files,
         valid_minutes=5.0,
         duration=1.0,
-        stride=0.5,
+        stride=None,
         sampling_rate=48000,
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
+        # augmentations: Optional[Compose] = None,
     ):
 
         super().__init__(
@@ -252,6 +264,7 @@ class EnhancerDataset(TaskDataset):
             matching_function=matching_function,
             batch_size=batch_size,
             num_workers=num_workers,
+            # augmentations=augmentations,
         )
 
         self.sampling_rate = sampling_rate
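Patch 09 replaces the `round(total_dur / self.duration)` estimate with `get_num_segments` called with `stride == duration`, i.e. non-overlapping windows (and since `prepare_traindata` still calls `self.get_num_segments`, an equivalent of the static method removed in patch 01 evidently survives on the class outside the hunks shown). A worked example of that formula, re-implemented here purely for illustration:

```python
import math

def get_num_segments(file_duration, duration, stride):
    # Formula from the method removed in patch 01: one window anchored at
    # every stride step, plus the final (possibly padded) window.
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1

print(get_num_segments(10.0, 4.5, 0.5))  # ceil(5.5 / 0.5) + 1 = 12 overlapping windows
print(get_num_segments(10.0, 4.5, 4.5))  # ceil(5.5 / 4.5) + 1 = 3 non-overlapping windows
print(get_num_segments(3.0, 4.5, 0.5))   # shorter than one window -> 1 segment
```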
From 542ab23d8a9d63ddddefa11e94634f641a4ecef6 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 24 Oct 2022 21:50:30 +0530
Subject: [PATCH 10/24] add torch-augmentations

---
 enhancer/data/dataset.py | 32 ++++++++++++++++++++++++--------
 1 file changed, 24 insertions(+), 8 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index ab2f1ce..1e0ec04 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -7,6 +7,7 @@ import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset
+from torch_audiomentations import Compose
 
 from enhancer.data.fileprocessor import Fileprocessor
 from enhancer.utils import check_files
@@ -14,9 +15,6 @@ from enhancer.utils.config import Files
 from enhancer.utils.io import Audio
 from enhancer.utils.random import create_unique_rng
 
-# from torch_audiomentations import Compose
-
-
 LARGE_NUM = 2147483647
 
 
@@ -66,7 +64,7 @@ class TaskDataset(pl.LightningDataModule):
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
-        # augmentations: Optional[Compose] = None,
+        augmentations: Optional[Compose] = None,
     ):
         super().__init__()
 
@@ -86,7 +84,7 @@ class TaskDataset(pl.LightningDataModule):
         else:
             raise ValueError("valid_minutes must be greater than 0")
 
-        # self.augmentations = augmentations
+        self.augmentations = augmentations
 
     def setup(self, stage: Optional[str] = None):
         """
@@ -178,7 +176,25 @@ class TaskDataset(pl.LightningDataModule):
         return metadata
 
     def train_collatefn(self, batch):
-        raise NotImplementedError("Not implemented")
+
+        output = {"clean": [], "noisy": []}
+        for item in batch:
+            output["clean"].append(item["clean"])
+            output["noisy"].append(item["noisy"])
+
+        output["clean"] = torch.stack(output["clean"], dim=0)
+        output["noisy"] = torch.stack(output["noisy"], dim=0)
+
+        if self.augmentations is not None:
+            output["clean"] = self.augmentations(
+                output["clean"], sample_rate=self.sampling_rate
+            )
+            self.augmentations.freeze_parameters()
+            output["noisy"] = self.augmentations(
+                output["noisy"], sample_rate=self.sampling_rate
+            )
+
+        return output
 
     @property
     def generator(self):
@@ -251,7 +267,7 @@ class EnhancerDataset(TaskDataset):
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
-        # augmentations: Optional[Compose] = None,
+        augmentations: Optional[Compose] = None,
     ):
 
         super().__init__(
@@ -264,7 +280,7 @@ class EnhancerDataset(TaskDataset):
             matching_function=matching_function,
             batch_size=batch_size,
             num_workers=num_workers,
-            # augmentations=augmentations,
+            augmentations=augmentations,
         )
 
         self.sampling_rate = sampling_rate
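Patch 10 runs one augmentation pipeline over both tensors and relies on `freeze_parameters()` so the second call reuses the random parameters drawn by the first, keeping clean and noisy aligned sample-for-sample. A standalone sketch with synthetic tensors (the `Shift` settings are illustrative, and `unfreeze_parameters()` is assumed to mirror `freeze_parameters()` on `Compose`):

```python
import torch
from torch_audiomentations import Compose, Shift

augment = Compose([Shift(min_shift=0.0, max_shift=0.25, p=1.0)])

clean = torch.randn(4, 1, 16000)                # (batch, channels, samples)
noisy = clean + 0.1 * torch.randn_like(clean)   # synthetic noisy mixture

clean_aug = augment(clean, sample_rate=16000)   # draws random shift amounts
augment.freeze_parameters()                     # lock in those exact shifts
noisy_aug = augment(noisy, sample_rate=16000)   # same shifts applied to noisy
augment.unfreeze_parameters()                   # let the next batch resample
```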
From 03d0dc57fc2c3617db0e21a1a43bcbed1cb74029 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 24 Oct 2022 22:13:19 +0530
Subject: [PATCH 11/24] add torch audiomentations

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index fa5e41c..cf8992d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,5 +14,6 @@ scikit-learn>=1.1.2
 scipy>=1.9.1
 soundfile>=0.11.0
 torch>=1.12.1
+torch-audiomentations==0.11.0
 torchaudio>=0.12.1
 tqdm>=4.64.1

From cdffe5c4852e51951ba9a775f9eb76c72cf61e21 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Tue, 25 Oct 2022 10:57:07 +0530
Subject: [PATCH 12/24] DEMUCS w/o stride

---
 enhancer/cli/train_config/dataset/Vctk.yaml            | 4 ++--
 enhancer/cli/train_config/hyperparameters/default.yaml | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index 2f22146..3f8def6 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -1,9 +1,9 @@
 _target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
-duration : 2
+duration : 4.5
 sampling_rate: 16000
-batch_size: 128
+batch_size: 32
 valid_minutes : 15
 files:
   train_clean : clean_trainset_28spk_wav
diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml
index 0291c8e..b6bba46 100644
--- a/enhancer/cli/train_config/hyperparameters/default.yaml
+++ b/enhancer/cli/train_config/hyperparameters/default.yaml
@@ -1,7 +1,7 @@
 loss : mse
 metric : [stoi,pesq,si-sdr]
-lr : 0.001
-ReduceLr_patience : 10
-ReduceLr_factor : 0.5
+lr : 0.0001
+ReduceLr_patience : 5
+ReduceLr_factor : 0.2
 min_lr : 0.000001
 EarlyStopping_factor : 10

From d1bafb3dc637aa51493656f677a15434af17b742 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Tue, 25 Oct 2022 12:43:54 +0530
Subject: [PATCH 13/24] add augmentations

---
 enhancer/cli/train.py    | 9 ++++++++-
 enhancer/data/dataset.py | 1 +
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 08f4d3e..6d5f182 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch_audiomentations import BandPassFilter, Compose, Shift
 
 os.environ["HYDRA_FULL_ERROR"] = "1"
 JOB_ID = os.environ.get("SLURM_JOBID", "0")
@@ -25,8 +26,14 @@ def main(config: DictConfig):
     )
 
     parameters = config.hyperparameters
+    apply_augmentations = Compose(
+        [
+            Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5),
+            BandPassFilter(p=0.5),
+        ]
+    )
 
-    dataset = instantiate(config.dataset)
+    dataset = instantiate(config.dataset, augmentations=apply_augmentations)
     model = instantiate(
         config.model,
         dataset=dataset,
diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 1e0ec04..f71d612 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -211,6 +211,7 @@ class TaskDataset(pl.LightningDataModule):
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             generator=self.generator,
+            collate_fn=self.train_collatefn,
         )
 
     def val_dataloader(self):
From b070613b647dfe6637fb52230e8cc14948a36d6b Mon Sep 17 00:00:00 2001
From: shahules786
Date: Tue, 25 Oct 2022 12:48:37 +0530
Subject: [PATCH 14/24] config

---
 enhancer/cli/train_config/config.yaml            | 2 +-
 enhancer/cli/train_config/dataset/Vctk.yaml      | 1 +
 enhancer/cli/train_config/mlflow/experiment.yaml | 2 +-
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/enhancer/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml
index c0b2cf6..8d0ab14 100644
--- a/enhancer/cli/train_config/config.yaml
+++ b/enhancer/cli/train_config/config.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - model : WaveUnet
+  - model : Demucs
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index 3f8def6..c33d29a 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -2,6 +2,7 @@ _target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 4.5
+stride : 2
 sampling_rate: 16000
 batch_size: 32
 valid_minutes : 15
diff --git a/enhancer/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml
index e8893f6..d597333 100644
--- a/enhancer/cli/train_config/mlflow/experiment.yaml
+++ b/enhancer/cli/train_config/mlflow/experiment.yaml
@@ -1,2 +1,2 @@
 experiment_name : shahules/enhancer
-run_name : baseline
+run_name : Demucs + Vctk with stride + augmentations

From 4acad6ede8c27fb4368940e771618ebdf297504d Mon Sep 17 00:00:00 2001
From: shahules786
Date: Tue, 25 Oct 2022 15:10:13 +0530
Subject: [PATCH 15/24] fix augmentation

---
 enhancer/data/dataset.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index f71d612..34ecb8f 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -186,12 +186,14 @@ class TaskDataset(pl.LightningDataModule):
         output["noisy"] = torch.stack(output["noisy"], dim=0)
 
         if self.augmentations is not None:
+            noise = output["noisy"] - output["clean"]
             output["clean"] = self.augmentations(
                 output["clean"], sample_rate=self.sampling_rate
             )
             self.augmentations.freeze_parameters()
-            output["noisy"] = self.augmentations(
-                output["noisy"], sample_rate=self.sampling_rate
+            output["noisy"] = (
+                self.augmentations(noise, sample_rate=self.sampling_rate)
+                + output["clean"]
             )
 
         return output

From 58de41598e90d6f62fc62bbe04f4a24636453031 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Tue, 25 Oct 2022 15:10:36 +0530
Subject: [PATCH 16/24] change matrix

---
 enhancer/cli/train_config/hyperparameters/default.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml
index b6bba46..1782ea9 100644
--- a/enhancer/cli/train_config/hyperparameters/default.yaml
+++ b/enhancer/cli/train_config/hyperparameters/default.yaml
@@ -1,6 +1,6 @@
-loss : mse
+loss : mae
 metric : [stoi,pesq,si-sdr]
-lr : 0.0001
+lr : 0.0003
 ReduceLr_patience : 5
 ReduceLr_factor : 0.2
 min_lr : 0.000001
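Patch 15 changes *what* is augmented: since `noisy = clean + noise`, it augments `clean` and the residual `noise` separately (with frozen parameters in between) and re-adds them, so the additive relationship still holds after augmentation. The same invariant condensed into a hedged helper (illustrative, not the repo's API):

```python
def augment_pair(clean, noisy, augmentations, sample_rate=16000):
    """Augment a clean/noisy pair while preserving noisy == clean + noise."""
    noise = noisy - clean                     # recover the additive noise
    clean = augmentations(clean, sample_rate=sample_rate)
    augmentations.freeze_parameters()         # reuse the clean transform's params
    noisy = augmentations(noise, sample_rate=sample_rate) + clean
    return clean, noisy
```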
From 04782ba6e931ef8a87bc9197bccc0859c6fe9774 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 26 Oct 2022 10:26:27 +0530
Subject: [PATCH 17/24] fix optimizer scheduler

---
 enhancer/cli/train.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 6d5f182..131db4f 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -4,10 +4,14 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import (
+    EarlyStopping,
+    LearningRateMonitor,
+    ModelCheckpoint,
+)
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch_audiomentations import BandPassFilter, Compose, Shift
+from torch_audiomentations import Compose, Shift
 
 os.environ["HYDRA_FULL_ERROR"] = "1"
 JOB_ID = os.environ.get("SLURM_JOBID", "0")
@@ -29,7 +33,6 @@ def main(config: DictConfig):
     apply_augmentations = Compose(
         [
             Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5),
-            BandPassFilter(p=0.5),
         ]
     )
 
@@ -52,6 +55,8 @@ def main(config: DictConfig):
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
+    callbacks.append(LearningRateMonitor(logging_interval="epoch"))
+
     if parameters.get("Early_stop", False):
         early_stopping = EarlyStopping(
             monitor="val_loss",
@@ -63,11 +68,11 @@ def main(config: DictConfig):
         )
         callbacks.append(early_stopping)
 
-    def configure_optimizer(self):
+    def configure_optimizers(self):
         optimizer = instantiate(
             config.optimizer,
             lr=parameters.get("lr"),
-            parameters=self.parameters(),
+            params=self.parameters(),
         )
         scheduler = ReduceLROnPlateau(
             optimizer=optimizer,
@@ -77,9 +82,13 @@ def main(config: DictConfig):
             min_lr=parameters.get("min_lr", 1e-6),
             patience=parameters.get("ReduceLr_patience", 3),
         )
-        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
+        }
 
-    model.configure_parameters = MethodType(configure_optimizer, model)
+    model.configure_optimizers = MethodType(configure_optimizers, model)
 
     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)

From 24a06ba9be6ef5d85d521587d66a4985883ba957 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 26 Oct 2022 10:27:23 +0530
Subject: [PATCH 18/24] rename loss

---
 enhancer/loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/loss.py b/enhancer/loss.py
index 2150699..fc8afae 100644
--- a/enhancer/loss.py
+++ b/enhancer/loss.py
@@ -66,7 +66,7 @@ class Si_SDR:
                 "Invalid reduction, valid options are sum, mean, None"
             )
         self.higher_better = False
-        self.name = "Si-SDR"
+        self.name = "si-sdr"
 
     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

From f07c8741ba650d67f2c1f1dc01c449ff9913c49a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 26 Oct 2022 11:59:58 +0530
Subject: [PATCH 19/24] fix resampling

---
 enhancer/utils/io.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py
index 9e9ce32..d151ef8 100644
--- a/enhancer/utils/io.py
+++ b/enhancer/utils/io.py
@@ -70,7 +70,7 @@ class Audio:
 
         if sampling_rate:
             audio = self.__class__.resample_audio(
-                audio, self.sampling_rate, sampling_rate
+                audio, sampling_rate, self.sampling_rate
             )
         if self.return_tensor:
             return torch.tensor(audio)
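Patch 17 fixes two genuine bugs: Lightning only ever calls a hook spelled `configure_optimizers` (the old `configure_optimizer` function assigned to `model.configure_parameters` was silently ignored), and a dict return that carries a `ReduceLROnPlateau` scheduler must include a `monitor` key naming the logged metric to watch. A minimal sketch of that contract on a toy module (illustrative, not the repo's `Model` class):

```python
from types import MethodType

import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

model = TinyModel()

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=5)
    # "monitor" tells Lightning which logged metric drives the plateau scheduler.
    return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "valid_loss"}

# Bind to the instance, as the patch does, so `self` is the live model object.
model.configure_optimizers = MethodType(configure_optimizers, model)
```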
From ee40259a8df9cf4467a14b7b16ba9ee742c660cd Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 26 Oct 2022 12:00:57 +0530
Subject: [PATCH 20/24] fix iterator

---
 enhancer/data/dataset.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 34ecb8f..f05fd6b 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -302,10 +302,11 @@ class EnhancerDataset(TaskDataset):
             if idx >= num_samples:
                 idx -= num_samples
                 continue
-            start = 0
-            if self.duration is not None:
-                start = idx * self.stride
-            return self.prepare_segment(filedict, start)
+            else:
+                start = 0
+                if self.duration is not None:
+                    start = idx * self.stride
+                return self.prepare_segment(filedict, start)
 
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])

From 1edc10e9f590a40ac54a35cba5b72dbdc87c2860 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 26 Oct 2022 12:01:19 +0530
Subject: [PATCH 21/24] time shift

---
 enhancer/cli/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 131db4f..5562cfd 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -32,7 +32,7 @@ def main(config: DictConfig):
     parameters = config.hyperparameters
     apply_augmentations = Compose(
         [
-            Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5),
+            Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
         ]
     )
 

From c51dea68859f20458801b2e4c8fd0a73f690175a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 26 Oct 2022 21:46:19 +0530
Subject: [PATCH 22/24] revert to torchmetric pesq

---
 enhancer/loss.py | 14 +++++---------
 requirements.txt |  2 +-
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/enhancer/loss.py b/enhancer/loss.py
index fc8afae..32b30cf 100644
--- a/enhancer/loss.py
+++ b/enhancer/loss.py
@@ -3,7 +3,7 @@ import logging
 import numpy as np
 import torch
 import torch.nn as nn
-from pesq import pesq
+from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
 from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
 
 
@@ -122,20 +122,16 @@ class Pesq:
         self.sr = sr
         self.name = "pesq"
         self.mode = mode
+        self.pesq = PerceptualEvaluationSpeechQuality(
+            fs=self.sr, mode=self.mode
+        )
 
     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
 
         pesq_values = []
        for pred, target_ in zip(prediction, target):
             try:
-                pesq_values.append(
-                    pesq(
-                        self.sr,
-                        target_.squeeze().detach().cpu().numpy(),
-                        pred.squeeze().detach().cpu().numpy(),
-                        self.mode,
-                    )
-                )
+                pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
             except Exception as e:
                 logging.warning(f"{e} error occured while calculating PESQ")
         return torch.tensor(np.mean(pesq_values))
diff --git a/requirements.txt b/requirements.txt
index cf8992d..fb54920 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
-git+https://github.com/ludlows/python-pesq#egg=pesq
+pesq==0.0.4
 protobuf>=3.19.6
 pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3
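Patch 22 moves PESQ behind the torchmetrics wrapper, which accepts tensors directly instead of the `.detach().cpu().numpy()` round-trip. A quick usage sketch (random tensors stand in for real 16 kHz audio, so the score itself is meaningless; the underlying `pesq` package must still be installed, hence the pin kept in requirements.txt):

```python
import torch
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality

pesq_metric = PerceptualEvaluationSpeechQuality(fs=16000, mode="wb")
preds = torch.randn(16000)   # enhanced signal (placeholder)
target = torch.randn(16000)  # clean reference (placeholder)
print(pesq_metric(preds, target))  # scalar MOS-LQO tensor
```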
From 47bbee2c32ee87ed006630ba984395ebc2185fdf Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 26 Oct 2022 21:47:29 +0530
Subject: [PATCH 23/24] rmv augmentations

---
 enhancer/cli/train.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 5562cfd..c00c024 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -11,7 +11,8 @@ from pytorch_lightning.callbacks import (
 )
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch_audiomentations import Compose, Shift
+
+# from torch_audiomentations import Compose, Shift
 
 os.environ["HYDRA_FULL_ERROR"] = "1"
 JOB_ID = os.environ.get("SLURM_JOBID", "0")
@@ -30,13 +31,13 @@ def main(config: DictConfig):
     )
 
     parameters = config.hyperparameters
-    apply_augmentations = Compose(
-        [
-            Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
-        ]
-    )
+    # apply_augmentations = Compose(
+    #     [
+    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
+    #     ]
+    # )
 
-    dataset = instantiate(config.dataset, augmentations=apply_augmentations)
+    dataset = instantiate(config.dataset, augmentations=None)
     model = instantiate(
         config.model,
         dataset=dataset,

From e1963ff001fe7785ca809a0df48799077938c8fc Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 27 Oct 2022 15:19:02 +0530
Subject: [PATCH 24/24] split validation criterion

---
 enhancer/data/dataset.py | 33 ++++++++++++++++++++++++---------
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index f05fd6b..dac2c50 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -1,8 +1,10 @@
 import math
 import multiprocessing
 import os
+from pathlib import Path
 from typing import Optional
 
+import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
@@ -119,16 +121,29 @@ class TaskDataset(pl.LightningDataModule):
     ):
 
         valid_minutes *= 60
-        valid_min_now = 0.0
+        valid_sec_now = 0.0
         valid_indices = []
-        random_indices = list(range(0, len(data)))
-        rng = create_unique_rng(random_state)
-        rng.shuffle(random_indices)
-        i = 0
-        while valid_min_now <= valid_minutes:
-            valid_indices.append(random_indices[i])
-            valid_min_now += data[random_indices[i]]["duration"]
-            i += 1
+        all_speakers = np.unique(
+            [
+                (Path(file["clean"]).name.split("_")[0], file["duration"])
+                for file in data
+            ]
+        )
+        possible_indices = list(range(0, len(all_speakers)))
+        rng = create_unique_rng(len(all_speakers))
+
+        while valid_sec_now <= valid_minutes:
+            speaker_index = rng.choice(possible_indices)
+            possible_indices.remove(speaker_index)
+            speaker_name = all_speakers[speaker_index]
+            file_indices = [
+                i
+                for i, file in enumerate(data)
+                if speaker_name == Path(file["clean"]).name.split("_")[0]
+            ]
+            for i in file_indices:
+                valid_indices.append(i)
+                valid_sec_now += data[i]["duration"]
 
         train_data = [
             item for i, item in enumerate(data) if i not in valid_indices