From fffdf02b932e9147c790aa73a0acfbf31ba12616 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 15:17:58 +0530
Subject: [PATCH 1/8] valid monitor fix

---
 cli/train.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/cli/train.py b/cli/train.py
index 88e513a..a5c83f0 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -4,7 +4,7 @@ from hydra.utils import instantiate
 from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
-
+os.environ["HYDRA_FULL_ERROR"] = "1"
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
@@ -20,14 +20,15 @@ def main(config: DictConfig):
     model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
                         loss=parameters.get("loss"), metric = parameters.get("metric"))
 
+    direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor="valid_loss",verbose=False,
-        mode="min",every_n_epochs=1
+        dirpath="",filename="model",monitor="val_loss",verbose=False,
+        mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
     early_stopping = EarlyStopping(
-        monitor="valid_loss",
-        mode="min",
+        monitor="val_loss",
+        mode=direction,
         min_delta=0.0,
         patience=100,
         strict=True,

From 1f4947103f12f2f9eed97a2b93468cc1e87bcea1 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 15:19:06 +0530
Subject: [PATCH 2/8] ensure loss direction

---
 enhancer/loss.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/enhancer/loss.py b/enhancer/loss.py
index 3bc6fa2..ef33161 100644
--- a/enhancer/loss.py
+++ b/enhancer/loss.py
@@ -1,5 +1,3 @@
-from modulefinder import Module
-from turtle import forward
 import torch
 import torch.nn as nn
 
@@ -10,6 +8,7 @@ class mean_squared_error(nn.Module):
         super().__init__()
 
         self.loss_fun = nn.MSELoss(reduction=reduction)
+        self.higher_better = False
 
     def forward(self,prediction:torch.Tensor, target: torch.Tensor):
 
@@ -25,6 +24,7 @@ class mean_absolute_error(nn.Module):
         super().__init__()
 
         self.loss_fun = nn.L1Loss(reduction=reduction)
+        self.higher_better = False
 
     def forward(self, prediction:torch.Tensor, target: torch.Tensor):
 
@@ -45,6 +45,7 @@ class Si_SDR(nn.Module):
             self.reduction = reduction
         else:
             raise TypeError("Invalid reduction, valid options are sum, mean, None")
+        self.higher_better = False
 
     def forward(self,prediction:torch.Tensor, target:torch.Tensor):
 
@@ -76,6 +77,12 @@ class Avergeloss(nn.Module):
         super().__init__()
 
         self.valid_losses = nn.ModuleList()
+
+        direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses]
+        if len(set(direction)) > 1:
+            raise ValueError("all cost functions should be of same nature, maximize or minimize!")
+
+        self.higher_better = direction[0]
         for loss in losses:
            loss = self.validate_loss(loss)
            self.valid_losses.append(loss())
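
[Note: illustration, not part of the patch series]
The higher_better flag added to each loss above is what lets the training script in
patch 1 pick the callback direction automatically: Avergeloss refuses to mix components
with opposite directions, and model.valid_monitor (added in the next patch) turns the
flag into the mode handed to ModelCheckpoint and EarlyStopping. A minimal, self-contained
sketch of that contract; the classes below are stand-ins, not code from enhancer/loss.py:

    # Stand-in components illustrating the higher_better contract.
    import torch.nn as nn

    class DummyMSE(nn.Module):       # stand-in for a loss such as mean_squared_error
        def __init__(self):
            super().__init__()
            self.higher_better = False   # a loss: lower values are better

    class DummyPESQ(nn.Module):      # stand-in for a quality-style metric
        def __init__(self):
            super().__init__()
            self.higher_better = True    # a metric: higher values are better

    def combined_direction(components):
        directions = {c.higher_better for c in components}
        if len(directions) > 1:          # mirrors the guard added to Avergeloss
            raise ValueError("all cost functions should be of the same nature")
        return directions.pop()

    mode = "max" if combined_direction([DummyMSE()]) else "min"
    print(mode)  # "min" -> the mode passed to ModelCheckpoint/EarlyStopping
    # combined_direction([DummyMSE(), DummyPESQ()]) raises ValueError (mixed directions).
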
From e2f570a8d1504cc8652ab61817b26753298afcb7 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 15:19:43 +0530
Subject: [PATCH 3/8] set cost property

---
 enhancer/models/model.py | 51 ++++++++++++++++++++++++++++++----------
 1 file changed, 39 insertions(+), 12 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 8e607ed..b1bdd86 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -1,3 +1,7 @@
+try:
+    from functools import cached_property
+except ImportError:
+    from backports.cached_property import cached_property
 from importlib import import_module
 from huggingface_hub import cached_download, hf_hub_url
 import logging
@@ -42,7 +46,34 @@ class Model(pl.LightningModule):
 
         if self.logger:
             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
-
+        self.loss = loss
+        self.metric = metric
+
+    @property
+    def loss(self):
+        return self._loss
+
+    @loss.setter
+    def loss(self,loss):
+
+        if isinstance(loss,str):
+            loss = [loss]
+
+        self._loss = Avergeloss(loss)
+
+    @property
+    def metric(self):
+        return self._metric
+
+    @metric.setter
+    def metric(self,metric):
+
+        if isinstance(metric,str):
+            metric = [metric]
+
+        self._metric = Avergeloss(metric)
+
+    @property
     def dataset(self):
         return self._dataset
 
@@ -55,16 +86,7 @@ class Model(pl.LightningModule):
 
         if stage == "fit":
             self.dataset.setup(stage)
             self.dataset.model = self
-            self.loss = self.setup_loss(self.hparams.loss)
-            self.metric = self.setup_loss(self.hparams.metric)
-
-    def setup_loss(self,loss):
-
-        if isinstance(loss,str):
-            losses = [loss]
-
-        return Avergeloss(losses)
-
+
     def train_dataloader(self):
         return self.dataset.train_dataloader()
@@ -224,7 +246,12 @@ class Model(pl.LightningModule):
 
             Inference.write_output(waveform,audio,model_sampling_rate)
 
         else:
-            return waveform
+            return waveform
+
+    @property
+    def valid_monitor(self):
+
+        return "max" if self.loss.higher_better else "min"

From 3c180da4447e289ac6a17f01e9006c266fd433a4 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 15:20:13 +0530
Subject: [PATCH 4/8] rmv root dir

---
 enhancer/utils/config.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py
index e9af6a0..1bbc51d 100644
--- a/enhancer/utils/config.py
+++ b/enhancer/utils/config.py
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 
 @dataclass
 class Files:
-    root_dir : str
     train_clean : str
     train_noisy : str
     test_clean : str

From 04ba785eb3a56f5afa8ed0cf8bc128ee29397e8d Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 15:37:04 +0530
Subject: [PATCH 5/8] add documentation

---
 enhancer/data/dataset.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 5749c36..fc871b8 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -91,7 +91,28 @@ class TaskDataset(pl.LightningDataModule):
         return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
 
 class EnhancerDataset(TaskDataset):
-    """Dataset object for creating clean-noisy speech enhancement datasets"""
+    """
+    Dataset object for creating clean-noisy speech enhancement datasets
+    parameters:
+    name : str
+        name of the dataset
+    root_dir : str
+        root directory of the dataset containing clean/noisy folders
+    files : Files
+        dataclass containing train_clean, train_noisy, test_clean, test_noisy
+        folder names (refer cli/train_config/dataset)
+    duration : float
+        expected audio duration of a single audio sample used for training
+    sampling_rate : int
+        desired sampling rate
+    batch_size : int
+        batch size of each batch
+    num_workers : int
+        num workers to be used while training
+    matching_function :
+        custom function for dataset processing.
+
+    """
 
     def __init__(
         self,
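
[Note: illustration, not part of the patch series]
A sketch of how the constructor documented above might be wired up. The keyword names
follow the docstring; the folder names and the root path are hypothetical:

    from enhancer.data.dataset import EnhancerDataset
    from enhancer.utils.config import Files

    files = Files(
        train_clean="clean_trainset_wav",
        train_noisy="noisy_trainset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )

    dataset = EnhancerDataset(
        name="vctk",                  # picks the VCTK-style one-to-one file matching
        root_dir="/data/valentini",   # hypothetical dataset root
        files=files,
        duration=1.0,
        sampling_rate=16000,
        batch_size=32,
        num_workers=4,
    )
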
From 74669990787f7cbbaf70897205decf8f5b324259 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 18:07:11 +0530
Subject: [PATCH 6/8] change to generic names

---
 enhancer/data/fileprocessor.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py
index f903375..d38e04f 100644
--- a/enhancer/data/fileprocessor.py
+++ b/enhancer/data/fileprocessor.py
@@ -7,7 +7,10 @@ from scipy.io import wavfile
 class ProcessorFunctions:
 
     @staticmethod
-    def match_vtck(clean_path,noisy_path):
+    def one_to_one(clean_path,noisy_path):
+        """
+        One clean audio file can have only one noisy audio file
+        """
 
         matching_wavfiles = list()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -27,7 +30,10 @@ class ProcessorFunctions:
         return matching_wavfiles
 
     @staticmethod
-    def match_dns2020(clean_path,noisy_path):
+    def one_to_many(clean_path,noisy_path):
+        """
+        One clean audio file can have multiple noisy audio files
+        """
 
         matching_wavfiles = dict()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -68,9 +74,9 @@ class Fileprocessor:
     ):
 
         if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
         elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
         else:
             return cls(clean_dir,noisy_dir, matching_function)

From 35bd3951fff19d7994a1b4f743247f32ed368d7e Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 19:03:06 +0530
Subject: [PATCH 7/8] simplify matching function

---
 enhancer/data/dataset.py       |  7 +++++--
 enhancer/data/fileprocessor.py | 18 +++++++++++++-----
 2 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index fc871b8..4c485c8 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -109,8 +109,11 @@ class EnhancerDataset(TaskDataset):
         batch size of each batch
     num_workers : int
         num workers to be used while training
-    matching_function :
-        custom function for dataset processing.
+    matching_function : str
+        matching functions - (one_to_one,one_to_many). Default set to None.
+        use one_to_one mapping for datasets with one noisy file for each clean file
+        use one_to_many mapping for multiple noisy files for each clean file
+
 
     """
diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py
index d38e04f..eab41a0 100644
--- a/enhancer/data/fileprocessor.py
+++ b/enhancer/data/fileprocessor.py
@@ -4,6 +4,8 @@ from re import S
 import numpy as np
 from scipy.io import wavfile
 
+MATCHING_FNS = ("one_to_one","one_to_many")
+
 class ProcessorFunctions:
 
     @staticmethod
@@ -73,12 +75,18 @@ class Fileprocessor:
         matching_function=None
     ):
 
-        if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
-        elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
+        if matching_function is None:
+            if name.lower() == "vctk":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
+            elif name.lower() == "dns-2020":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
         else:
-            return cls(clean_dir,noisy_dir, matching_function)
+            if matching_function not in MATCHING_FNS:
+                raise ValueError(F"Invalid matching function! Available options are {MATCHING_FNS}")
+            else:
+                return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
+
+
     def prepare_matching_dict(self):
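
[Note: illustration, not part of the patch series]
After this patch a matching function is selected by name instead of being passed as a
callable. A usage sketch; the argument order of from_name is assumed and the directory
paths are made up:

    from enhancer.data.fileprocessor import Fileprocessor

    # Known dataset names still map automatically (vctk -> one_to_one).
    fp = Fileprocessor.from_name("vctk", "/data/clean", "/data/noisy")

    # Custom datasets pass "one_to_one" or "one_to_many" by name; anything
    # outside MATCHING_FNS raises ValueError.
    fp = Fileprocessor.from_name(
        "my-dataset", "/data/clean", "/data/noisy",
        matching_function="one_to_many",
    )
    matching = fp.prepare_matching_dict()
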
From 3c9a3ab3f3bf8e07b47713e1b46f97e5374e7d14 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sun, 2 Oct 2022 08:48:14 +0530
Subject: [PATCH 8/8] relative imports

---
 enhancer/utils/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py
index a8ae539..3da7ede 100644
--- a/enhancer/utils/__init__.py
+++ b/enhancer/utils/__init__.py
@@ -1,2 +1,3 @@
 from enhancer.utils.utils import check_files
-from enhancer.utils.io import Audio
\ No newline at end of file
+from enhancer.utils.io import Audio
+from enhancer.utils.config import Files
\ No newline at end of file
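
[Note: illustration, not part of the patch series]
A quick check of the re-export added in the final patch:

    # Files now ships from the package root alongside the existing helpers.
    from enhancer.utils import Audio, Files, check_files

    # Files is the dataclass from enhancer/utils/config.py holding the clean/noisy
    # folder names (root_dir was dropped in patch 4).
    print(list(Files.__dataclass_fields__))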