Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-14 11:32:59 +05:30
commit 315d646347
3 changed files with 27 additions and 12 deletions

View File

@ -1,8 +1,8 @@
<p align="center">
<img src="https://user-images.githubusercontent.com/25312635/195507951-fe64657c-9114-4d78-b04e-444e6d5bbcc4.png" />
<img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" />
</p>
mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training .
mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training.
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
## Key features :key:

View File

@ -4,7 +4,7 @@ from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
@ -45,15 +45,16 @@ def main(config: DictConfig):
every_n_epochs=1,
)
callbacks.append(checkpoint)
# early_stopping = EarlyStopping(
# monitor="val_loss",
# mode=direction,
# min_delta=0.0,
# patience=parameters.get("EarlyStopping_patience", 10),
# strict=True,
# verbose=False,
# )
# callbacks.append(early_stopping)
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10),
strict=True,
verbose=False,
)
callbacks.append(early_stopping)
def configure_optimizer(self):
optimizer = instantiate(

View File

@ -113,6 +113,20 @@ class Model(pl.LightningModule):
if stage == "fit":
torch.cuda.empty_cache()
self.dataset.setup(stage)
print(
"Total train duration",
self.dataset.train_dataloader().__len__()
* self.dataset.duration
/ 60,
"minutes",
)
print(
"Total validation duration",
self.dataset.val_dataloader().__len__()
* self.dataset.duration
/ 60,
"minutes",
)
self.dataset.model = self
def train_dataloader(self):