From 891446f7db2fd20f0e2e86ee5394dc30d6436fd2 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 13 Oct 2022 11:34:28 +0530
Subject: [PATCH 1/3] add tagline

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 586df14..13b8e14 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@

-
+

-mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training .
+mayavoz is a PyTorch-based open-source toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training.
 
 | **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
 
 ## Key features :key:

From 204de08a9ac065f32997aab12c34738e0debced2 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 14 Oct 2022 10:52:56 +0530
Subject: [PATCH 2/3] add early stopping

---
 enhancer/cli/train.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 49f7b3b..08f4d3e 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -4,7 +4,7 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
@@ -45,15 +45,16 @@ def main(config: DictConfig):
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
-    # early_stopping = EarlyStopping(
-    #     monitor="val_loss",
-    #     mode=direction,
-    #     min_delta=0.0,
-    #     patience=parameters.get("EarlyStopping_patience", 10),
-    #     strict=True,
-    #     verbose=False,
-    # )
-    # callbacks.append(early_stopping)
+    if parameters.get("Early_stop", False):
+        early_stopping = EarlyStopping(
+            monitor="val_loss",
+            mode=direction,
+            min_delta=0.0,
+            patience=parameters.get("EarlyStopping_patience", 10),
+            strict=True,
+            verbose=False,
+        )
+        callbacks.append(early_stopping)
 
     def configure_optimizer(self):
         optimizer = instantiate(

From 6a3c67fc13c8d11ef00e75eb79e47e79d7eb9bf9 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 14 Oct 2022 11:32:18 +0530
Subject: [PATCH 3/3] print train/val duration

---
 enhancer/models/model.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 4f055b4..714b0e5 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -113,6 +113,20 @@ class Model(pl.LightningModule):
         if stage == "fit":
             torch.cuda.empty_cache()
             self.dataset.setup(stage)
+            print(
+                "Total train duration",
+                self.dataset.train_dataloader().__len__()
+                * self.dataset.duration
+                / 60,
+                "minutes",
+            )
+            print(
+                "Total validation duration",
+                self.dataset.val_dataloader().__len__()
+                * self.dataset.duration
+                / 60,
+                "minutes",
+            )
             self.dataset.model = self
 
     def train_dataloader(self):
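
A minimal sketch (not part of the patches above, and not taken from the mayavoz CLI) of how the `Early_stop` switch added in PATCH 2/3 is expected to behave once the callbacks reach a `pytorch_lightning.Trainer`; the `parameters` values and `direction` below are illustrative assumptions rather than values from any shipped config.

# Sketch only: the Early_stop gate from PATCH 2/3, outside the hydra CLI.
# The `parameters` dict and `direction` are assumed, illustrative values.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

parameters = {"Early_stop": True, "EarlyStopping_patience": 10}  # hypothetical config values
direction = "min"  # val_loss is minimised

callbacks = [ModelCheckpoint(monitor="val_loss", mode=direction, every_n_epochs=1)]
if parameters.get("Early_stop", False):
    callbacks.append(
        EarlyStopping(
            monitor="val_loss",
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
    )

trainer = pl.Trainer(callbacks=callbacks)
# trainer.fit(model, datamodule) would then stop early once val_loss fails to
# improve for `patience` consecutive validation epochs.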
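
For the printouts added in PATCH 3/3, a quick worked example of the arithmetic, assuming `self.dataset.duration` is the per-chunk length in seconds and that the dataloader length counts batches (both are assumptions about the mayavoz data API, and the numbers are illustrative):

# Worked example for the duration printout in PATCH 3/3 (illustrative numbers only).
num_train_batches = 1200  # what train_dataloader().__len__() would return
chunk_duration_s = 2.0    # assumed value of self.dataset.duration, in seconds

print("Total train duration", num_train_batches * chunk_duration_s / 60, "minutes")
# -> Total train duration 40.0 minutes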