Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-14 11:32:59 +05:30
commit 315d646347
3 changed files with 27 additions and 12 deletions

View File

@@ -1,8 +1,8 @@
<p align="center"> <p align="center">
<img src="https://user-images.githubusercontent.com/25312635/195507951-fe64657c-9114-4d78-b04e-444e6d5bbcc4.png" /> <img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" />
</p> </p>
mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training. mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy to use pretrained audio enhancement models and facilitates highly customisable model training.
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()** | **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
## Key features :key: ## Key features :key:

View File

@@ -4,7 +4,7 @@ from types import MethodType
import hydra import hydra
from hydra.utils import instantiate from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -45,15 +45,16 @@ def main(config: DictConfig):
every_n_epochs=1, every_n_epochs=1,
) )
callbacks.append(checkpoint) callbacks.append(checkpoint)
# early_stopping = EarlyStopping( if parameters.get("Early_stop", False):
# monitor="val_loss", early_stopping = EarlyStopping(
# mode=direction, monitor="val_loss",
# min_delta=0.0, mode=direction,
# patience=parameters.get("EarlyStopping_patience", 10), min_delta=0.0,
# strict=True, patience=parameters.get("EarlyStopping_patience", 10),
# verbose=False, strict=True,
# ) verbose=False,
# callbacks.append(early_stopping) )
callbacks.append(early_stopping)
def configure_optimizer(self): def configure_optimizer(self):
optimizer = instantiate( optimizer = instantiate(

View File

@@ -113,6 +113,20 @@ class Model(pl.LightningModule):
if stage == "fit": if stage == "fit":
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.dataset.setup(stage) self.dataset.setup(stage)
print(
"Total train duration",
self.dataset.train_dataloader().__len__()
* self.dataset.duration
/ 60,
"minutes",
)
print(
"Total validation duration",
self.dataset.val_dataloader().__len__()
* self.dataset.duration
/ 60,
"minutes",
)
self.dataset.model = self self.dataset.model = self
def train_dataloader(self): def train_dataloader(self):