diff --git a/README.md b/README.md index 586df14..13b8e14 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@

- +

-mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training . +mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable model training. | **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()** ## Key features :key: diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index 49f7b3b..08f4d3e 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -4,7 +4,7 @@ from types import MethodType import hydra from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import MLFlowLogger from torch.optim.lr_scheduler import ReduceLROnPlateau @@ -45,15 +45,16 @@ def main(config: DictConfig): every_n_epochs=1, ) callbacks.append(checkpoint) - # early_stopping = EarlyStopping( - # monitor="val_loss", - # mode=direction, - # min_delta=0.0, - # patience=parameters.get("EarlyStopping_patience", 10), - # strict=True, - # verbose=False, - # ) - # callbacks.append(early_stopping) + if parameters.get("Early_stop", False): + early_stopping = EarlyStopping( + monitor="val_loss", + mode=direction, + min_delta=0.0, + patience=parameters.get("EarlyStopping_patience", 10), + strict=True, + verbose=False, + ) + callbacks.append(early_stopping) def configure_optimizer(self): optimizer = instantiate( diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 4f055b4..714b0e5 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -113,6 +113,20 @@ class Model(pl.LightningModule): if stage == "fit": torch.cuda.empty_cache() self.dataset.setup(stage) + print( + "Total train duration", + self.dataset.train_dataloader().__len__() + * self.dataset.duration + / 60, + "minutes", + ) + print( + "Total validation duration", + self.dataset.val_dataloader().__len__() + * self.dataset.duration + / 60, + "minutes", + ) self.dataset.model = self def train_dataloader(self):