diff --git a/README.md b/README.md
index 586df14..13b8e14 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
-
+
-mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training .
+mayavoz is a PyTorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy to use pretrained audio enhancement models and facilitates highly customisable model training.
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
## Key features :key:
diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 49f7b3b..08f4d3e 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -4,7 +4,7 @@ from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -45,15 +45,16 @@ def main(config: DictConfig):
every_n_epochs=1,
)
callbacks.append(checkpoint)
- # early_stopping = EarlyStopping(
- # monitor="val_loss",
- # mode=direction,
- # min_delta=0.0,
- # patience=parameters.get("EarlyStopping_patience", 10),
- # strict=True,
- # verbose=False,
- # )
- # callbacks.append(early_stopping)
+ if parameters.get("Early_stop", False):
+ early_stopping = EarlyStopping(
+ monitor="val_loss",
+ mode=direction,
+ min_delta=0.0,
+ patience=parameters.get("EarlyStopping_patience", 10),
+ strict=True,
+ verbose=False,
+ )
+ callbacks.append(early_stopping)
def configure_optimizer(self):
optimizer = instantiate(
diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 4f055b4..714b0e5 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -113,6 +113,20 @@ class Model(pl.LightningModule):
if stage == "fit":
torch.cuda.empty_cache()
self.dataset.setup(stage)
+ print(
+ "Total train duration",
+ self.dataset.train_dataloader().__len__()
+ * self.dataset.duration
+ / 60,
+ "minutes",
+ )
+ print(
+ "Total validation duration",
+ self.dataset.val_dataloader().__len__()
+ * self.dataset.duration
+ / 60,
+ "minutes",
+ )
self.dataset.model = self
def train_dataloader(self):