Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
This commit is contained in:
commit
315d646347
|
|
@ -1,8 +1,8 @@
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://user-images.githubusercontent.com/25312635/195507951-fe64657c-9114-4d78-b04e-444e6d5bbcc4.png" />
|
<img src="https://user-images.githubusercontent.com/25312635/195514652-e4526cd1-1177-48e9-a80d-c8bfdb95d35f.png" />
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training .
|
mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable model training.
|
||||||
|
|
||||||
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
|
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
|
||||||
## Key features :key:
|
## Key features :key:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from types import MethodType
|
||||||
import hydra
|
import hydra
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
from pytorch_lightning.loggers import MLFlowLogger
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
|
|
||||||
|
|
@ -45,15 +45,16 @@ def main(config: DictConfig):
|
||||||
every_n_epochs=1,
|
every_n_epochs=1,
|
||||||
)
|
)
|
||||||
callbacks.append(checkpoint)
|
callbacks.append(checkpoint)
|
||||||
# early_stopping = EarlyStopping(
|
if parameters.get("Early_stop", False):
|
||||||
# monitor="val_loss",
|
early_stopping = EarlyStopping(
|
||||||
# mode=direction,
|
monitor="val_loss",
|
||||||
# min_delta=0.0,
|
mode=direction,
|
||||||
# patience=parameters.get("EarlyStopping_patience", 10),
|
min_delta=0.0,
|
||||||
# strict=True,
|
patience=parameters.get("EarlyStopping_patience", 10),
|
||||||
# verbose=False,
|
strict=True,
|
||||||
# )
|
verbose=False,
|
||||||
# callbacks.append(early_stopping)
|
)
|
||||||
|
callbacks.append(early_stopping)
|
||||||
|
|
||||||
def configure_optimizer(self):
|
def configure_optimizer(self):
|
||||||
optimizer = instantiate(
|
optimizer = instantiate(
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,20 @@ class Model(pl.LightningModule):
|
||||||
if stage == "fit":
|
if stage == "fit":
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
self.dataset.setup(stage)
|
self.dataset.setup(stage)
|
||||||
|
print(
|
||||||
|
"Total train duration",
|
||||||
|
self.dataset.train_dataloader().__len__()
|
||||||
|
* self.dataset.duration
|
||||||
|
/ 60,
|
||||||
|
"minutes",
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"Total validation duration",
|
||||||
|
self.dataset.val_dataloader().__len__()
|
||||||
|
* self.dataset.duration
|
||||||
|
/ 60,
|
||||||
|
"minutes",
|
||||||
|
)
|
||||||
self.dataset.model = self
|
self.dataset.model = self
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue