diff --git a/cli/train.py b/cli/train.py
index 88e513a..a5c83f0 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -4,7 +4,7 @@ from hydra.utils import instantiate
 from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
-
+os.environ["HYDRA_FULL_ERROR"] = "1"
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
@@ -20,14 +20,15 @@ def main(config: DictConfig):
     model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
                         loss=parameters.get("loss"), metric = parameters.get("metric"))
 
+    direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor="valid_loss",verbose=False,
-        mode="min",every_n_epochs=1
+        dirpath="",filename="model",monitor="val_loss",verbose=False,
+        mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
     early_stopping = EarlyStopping(
-        monitor="valid_loss",
-        mode="min",
+        monitor="val_loss",
+        mode=direction,
        min_delta=0.0,
         patience=100,
         strict=True,
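For reference, a minimal sketch of what the callback wiring in `main` now amounts to, assuming only what this patch shows (`model.valid_monitor` returning `"min"` or `"max"`, and `val_loss` being the logged validation key; the `build_callbacks` helper is hypothetical):

```python
# Sketch: checkpointing and early stopping follow the loss direction
# instead of hard-coding mode="min". Names are taken from this patch;
# build_callbacks is a hypothetical wrapper for illustration.
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

def build_callbacks(model):
    # "min" for objectives to minimize, "max" for ones to maximize,
    # as reported by Model.valid_monitor (added further down in this diff).
    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="", filename="model", monitor="val_loss",
        verbose=False, mode=direction, every_n_epochs=1,
    )
    early_stopping = EarlyStopping(
        monitor="val_loss", mode=direction,
        min_delta=0.0, patience=100, strict=True,
    )
    return [checkpoint, early_stopping]
```

Deriving the direction once and passing it to both callbacks keeps checkpoint selection and early stopping from ever disagreeing about which way "better" points.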
diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 5749c36..4c485c8 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -91,7 +91,31 @@ class TaskDataset(pl.LightningDataModule):
         return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
 
 class EnhancerDataset(TaskDataset):
-    """Dataset object for creating clean-noisy speech enhancement datasets"""
+    """
+    Dataset object for creating clean-noisy speech enhancement datasets
+
+    parameters:
+    name : str
+        name of the dataset
+    root_dir : str
+        root directory of the dataset containing clean/noisy folders
+    files : Files
+        dataclass containing train_clean, train_noisy, test_clean, test_noisy
+        folder names (refer cli/train_config/dataset)
+    duration : float
+        expected duration of a single audio sample for training
+    sampling_rate : int
+        desired sampling rate
+    batch_size : int
+        batch size of each batch
+    num_workers : int
+        number of workers to be used while training
+    matching_function : str
+        matching function - one of (one_to_one, one_to_many). Default is None.
+        Use one_to_one for datasets with one noisy file for each clean file;
+        use one_to_many for datasets with multiple noisy files per clean file.
+    """
 
     def __init__(
         self,
diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py
index f903375..eab41a0 100644
--- a/enhancer/data/fileprocessor.py
+++ b/enhancer/data/fileprocessor.py
@@ -4,10 +4,15 @@ from re import S
 import numpy as np
 from scipy.io import wavfile
 
+MATCHING_FNS = ("one_to_one","one_to_many")
+
 class ProcessorFunctions:
 
     @staticmethod
-    def match_vtck(clean_path,noisy_path):
+    def one_to_one(clean_path,noisy_path):
+        """
+        Each clean audio file has exactly one noisy audio file
+        """
 
         matching_wavfiles = list()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -27,7 +32,10 @@ class ProcessorFunctions:
         return matching_wavfiles
 
     @staticmethod
-    def match_dns2020(clean_path,noisy_path):
+    def one_to_many(clean_path,noisy_path):
+        """
+        Each clean audio file can have multiple noisy audio files
+        """
 
         matching_wavfiles = dict()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -67,12 +75,18 @@ class Fileprocessor:
         matching_function=None
     ):
 
-        if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
-        elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
+        if matching_function is None:
+            if name.lower() == "vctk":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
+            elif name.lower() == "dns-2020":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
         else:
-            return cls(clean_dir,noisy_dir, matching_function)
+            if matching_function not in MATCHING_FNS:
+                raise ValueError(f"Invalid matching function! Available options are {MATCHING_FNS}")
+            else:
+                return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
 
 
     def prepare_matching_dict(self):
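A quick usage sketch of the new dispatch. The factory classmethod's name falls outside this hunk, so `from_name` is an assumption here, as are the placeholder paths:

```python
# Assumed factory signature: from_name(name, clean_dir, noisy_dir,
# matching_function=None) -- only the body is visible in this diff.
from enhancer.data.fileprocessor import Fileprocessor

# Known dataset names keep their implicit mapping:
# "vctk" -> one_to_one, "dns-2020" -> one_to_many.
fp = Fileprocessor.from_name("vctk", "data/clean", "data/noisy")

# Other datasets must name a matching function explicitly; the string is
# resolved against ProcessorFunctions, and anything outside MATCHING_FNS
# now raises ValueError instead of being passed through blindly.
fp = Fileprocessor.from_name(
    "my-dataset", "data/clean", "data/noisy",
    matching_function="one_to_many",
)
```

Note that an unrecognized name with `matching_function=None` still falls through both branches and implicitly returns `None`; the new validation only covers the explicit path.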
diff --git a/enhancer/loss.py b/enhancer/loss.py
index 3bc6fa2..ef33161 100644
--- a/enhancer/loss.py
+++ b/enhancer/loss.py
@@ -1,5 +1,3 @@
-from modulefinder import Module
-from turtle import forward
 import torch
 import torch.nn as nn
 
@@ -10,6 +8,7 @@ class mean_squared_error(nn.Module):
         super().__init__()
         self.loss_fun = nn.MSELoss(reduction=reduction)
+        self.higher_better = False
 
     def forward(self,prediction:torch.Tensor, target: torch.Tensor):
 
@@ -25,6 +24,7 @@ class mean_absolute_error(nn.Module):
         super().__init__()
         self.loss_fun = nn.L1Loss(reduction=reduction)
+        self.higher_better = False
 
     def forward(self, prediction:torch.Tensor, target: torch.Tensor):
 
@@ -45,6 +45,7 @@ class Si_SDR(nn.Module):
             self.reduction = reduction
         else:
             raise TypeError("Invalid reduction, valid options are sum, mean, None")
+        self.higher_better = False
 
     def forward(self,prediction:torch.Tensor, target:torch.Tensor):
 
@@ -76,6 +77,12 @@ class Avergeloss(nn.Module):
 
         super().__init__()
         self.valid_losses = nn.ModuleList()
+
+        direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses]
+        if len(set(direction)) > 1:
+            raise ValueError("all cost functions should be of the same nature, maximize or minimize!")
+
+        self.higher_better = direction[0]
         for loss in losses:
             loss = self.validate_loss(loss)
             self.valid_losses.append(loss())
diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 8e607ed..b1bdd86 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -1,3 +1,7 @@
+try:
+    from functools import cached_property
+except ImportError:
+    from backports.cached_property import cached_property
 from importlib import import_module
 from huggingface_hub import cached_download, hf_hub_url
 import logging
@@ -42,7 +46,34 @@ class Model(pl.LightningModule):
         if self.logger:
             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
-
+        self.loss = loss
+        self.metric = metric
+
+    @property
+    def loss(self):
+        return self._loss
+
+    @loss.setter
+    def loss(self,loss):
+
+        if isinstance(loss,str):
+            loss = [loss]
+
+        self._loss = Avergeloss(loss)
+
+    @property
+    def metric(self):
+        return self._metric
+
+    @metric.setter
+    def metric(self,metric):
+
+        if isinstance(metric,str):
+            metric = [metric]
+
+        self._metric = Avergeloss(metric)
+
+    @property
     def dataset(self):
         return self._dataset
@@ -55,16 +86,7 @@ class Model(pl.LightningModule):
         if stage == "fit":
             self.dataset.setup(stage)
             self.dataset.model = self
-        self.loss = self.setup_loss(self.hparams.loss)
-        self.metric = self.setup_loss(self.hparams.metric)
-
-    def setup_loss(self,loss):
-
-        if isinstance(loss,str):
-            losses = [loss]
-
-        return Avergeloss(losses)
-
+
     def train_dataloader(self):
         return self.dataset.train_dataloader()
@@ -224,7 +246,12 @@ class Model(pl.LightningModule):
                 Inference.write_output(waveform,audio,model_sampling_rate)
             else:
-                return waveform
+                return waveform
+
+    @property
+    def valid_monitor(self):
+
+        return "max" if self.loss.higher_better else "min"
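Taken together, these pieces give the trainer its direction: every loss declares `higher_better`, `Avergeloss` refuses to mix directions, and `Model.valid_monitor` translates the result into a Lightning `mode`. A sketch of the flow, assuming `LOSS_MAP` uses keys like `"mse"` and `"mae"` (the map itself is outside this diff):

```python
from enhancer.loss import Avergeloss

# Both constituents minimize, so the combined loss does too ...
loss = Avergeloss(["mse", "mae"])
assert loss.higher_better is False
# ... and Model.valid_monitor would therefore report "min".

# Combining a minimized loss with a maximized objective (hypothetical key)
# now fails fast instead of silently checkpointing in the wrong direction:
# Avergeloss(["mse", "some_maximized_metric"])  # -> ValueError
```

Probing the direction in the `Avergeloss` constructor means a misconfigured loss/metric mix is caught at model construction, before a single training step runs.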
diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py
index a8ae539..3da7ede 100644
--- a/enhancer/utils/__init__.py
+++ b/enhancer/utils/__init__.py
@@ -1,2 +1,3 @@
 from enhancer.utils.utils import check_files
-from enhancer.utils.io import Audio
\ No newline at end of file
+from enhancer.utils.io import Audio
+from enhancer.utils.config import Files
\ No newline at end of file
diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py
index e9af6a0..1bbc51d 100644
--- a/enhancer/utils/config.py
+++ b/enhancer/utils/config.py
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 
 @dataclass
 class Files:
-    root_dir : str
     train_clean : str
     train_noisy : str
     test_clean : str
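With `root_dir` dropped from `Files`, dataset location and folder names are now split cleanly between the dataset object and the dataclass. A sketch of the resulting wiring, with placeholder folder names and paths, assuming `EnhancerDataset.__init__` accepts the keywords its new docstring lists:

```python
from enhancer.data.dataset import EnhancerDataset
from enhancer.utils import Files  # now re-exported by the package

# Files carries only the four split folder names ...
files = Files(
    train_clean="clean_trainset_wav",
    train_noisy="noisy_trainset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)

# ... while the dataset owns the root location, so the same Files
# config can be reused across machines with different data paths.
dataset = EnhancerDataset(
    name="vctk",
    root_dir="/data/vctk",
    files=files,
    matching_function="one_to_one",
)
```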