Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-02 08:48:43 +05:30
commit 10ec1a76c8
7 changed files with 102 additions and 29 deletions

View File

@@ -4,7 +4,7 @@ from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
os.environ["HYDRA_FULL_ERROR"] = "1"
@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):
@@ -20,14 +20,15 @@ def main(config: DictConfig):
     model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
                 loss=parameters.get("loss"), metric = parameters.get("metric"))
+    direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor="valid_loss",verbose=False,
-        mode="min",every_n_epochs=1
+        dirpath="",filename="model",monitor="val_loss",verbose=False,
+        mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
     early_stopping = EarlyStopping(
-        monitor="valid_loss",
-        mode="min",
+        monitor="val_loss",
+        mode=direction,
         min_delta=0.0,
         patience=100,
         strict=True,
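The practical effect of this hunk: the checkpoint and early-stopping direction is no longer hard-coded to "min" but derived from the loss itself via model.valid_monitor (added further down in model.py). A minimal sketch of the resulting wiring, assuming a standard pytorch_lightning Trainer around it (the Trainer setup is not part of this diff):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# "max" when the training objective is higher-better, "min" otherwise
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
    dirpath="", filename="model", monitor="val_loss",
    verbose=False, mode=direction, every_n_epochs=1,
)
early_stopping = EarlyStopping(
    monitor="val_loss", mode=direction, min_delta=0.0,
    patience=100, strict=True,
)
trainer = Trainer(callbacks=[checkpoint, early_stopping])
trainer.fit(model)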

View File

@@ -91,7 +91,31 @@ class TaskDataset(pl.LightningDataModule):
         return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)

 class EnhancerDataset(TaskDataset):
-    """Dataset object for creating clean-noisy speech enhancement datasets"""
+    """
+    Dataset object for creating clean-noisy speech enhancement datasets
+
+    parameters:
+    name : str
+        name of the dataset
+    root_dir : str
+        root directory of the dataset containing clean/noisy folders
+    files : Files
+        dataclass containing train_clean, train_noisy, test_clean, test_noisy
+        folder names (refer cli/train_config/dataset)
+    duration : float
+        expected audio duration of a single audio sample for training
+    sampling_rate : int
+        desired sampling rate
+    batch_size : int
+        batch size of each batch
+    num_workers : int
+        number of workers to be used while training
+    matching_function : str
+        matching functions - (one_to_one, one_to_many). Defaults to None.
+        use one_to_one mapping for datasets with one noisy file for each clean file
+        use one_to_many mapping for multiple noisy files for each clean file
+    """

     def __init__(
         self,
View File

@@ -4,10 +4,15 @@ from re import S
 import numpy as np
 from scipy.io import wavfile

+MATCHING_FNS = ("one_to_one","one_to_many")

 class ProcessorFunctions:
     @staticmethod
-    def match_vtck(clean_path,noisy_path):
+    def one_to_one(clean_path,noisy_path):
+        """
+        One clean audio can have only one noisy audio file
+        """
         matching_wavfiles = list()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -27,7 +32,10 @@
         return matching_wavfiles

     @staticmethod
-    def match_dns2020(clean_path,noisy_path):
+    def one_to_many(clean_path,noisy_path):
+        """
+        One clean audio can have multiple noisy audio files
+        """
         matching_wavfiles = dict()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -67,12 +75,18 @@ class Fileprocessor:
         matching_function=None
     ):
-        if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
-        elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
+        if matching_function is None:
+            if name.lower() == "vctk":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
+            elif name.lower() == "dns-2020":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
         else:
-            return cls(clean_dir,noisy_dir, matching_function)
+            if matching_function not in MATCHING_FNS:
+                raise ValueError(f"Invalid matching function! Available options are {MATCHING_FNS}")
+            else:
+                return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))

     def prepare_matching_dict(self):
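Beyond the rename, this refactor makes the matching strategy addressable by name: because one_to_one and one_to_many are static methods on ProcessorFunctions, a string from a config file resolves directly to the callable via getattr. A self-contained sketch of that dispatch pattern (the resolve helper and print bodies are illustrative, not part of the diff):

class ProcessorFunctions:
    @staticmethod
    def one_to_one(clean_path, noisy_path):
        return f"pairing files under {clean_path} and {noisy_path} 1:1"

    @staticmethod
    def one_to_many(clean_path, noisy_path):
        return f"pairing each clean file under {clean_path} with many noisy files"

MATCHING_FNS = ("one_to_one", "one_to_many")

def resolve(matching_function):  # illustrative helper
    if matching_function not in MATCHING_FNS:
        raise ValueError(f"Invalid matching function! Available options are {MATCHING_FNS}")
    return getattr(ProcessorFunctions, matching_function)

match_fn = resolve("one_to_many")
print(match_fn("/data/clean", "/data/noisy"))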

View File

@@ -1,5 +1,3 @@
-from modulefinder import Module
-from turtle import forward
 import torch
 import torch.nn as nn
@@ -10,6 +8,7 @@ class mean_squared_error(nn.Module):
         super().__init__()
         self.loss_fun = nn.MSELoss(reduction=reduction)
+        self.higher_better = False

     def forward(self,prediction:torch.Tensor, target: torch.Tensor):
@@ -25,6 +24,7 @@ class mean_absolute_error(nn.Module):
         super().__init__()
         self.loss_fun = nn.L1Loss(reduction=reduction)
+        self.higher_better = False

     def forward(self, prediction:torch.Tensor, target: torch.Tensor):
@@ -45,6 +45,7 @@ class Si_SDR(nn.Module):
             self.reduction = reduction
         else:
             raise TypeError("Invalid reduction, valid options are sum, mean, None")
+        self.higher_better = False

     def forward(self,prediction:torch.Tensor, target:torch.Tensor):
@@ -76,6 +77,12 @@ class Avergeloss(nn.Module):
         super().__init__()
         self.valid_losses = nn.ModuleList()
+        direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses]
+        if len(set(direction)) > 1:
+            raise ValueError("all cost functions should be of the same nature, maximize or minimize!")
+        self.higher_better = direction[0]

         for loss in losses:
             loss = self.validate_loss(loss)
             self.valid_losses.append(loss())
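The new guard stops Avergeloss from averaging objectives that pull in opposite directions, e.g. a minimize-style loss with a maximize-style score. A toy reproduction of the check (the two classes below are stand-ins, not the repo's actual LOSS_MAP entries):

import torch.nn as nn

class ToyMSE(nn.Module):  # stand-in: lower is better
    def __init__(self):
        super().__init__()
        self.higher_better = False

class ToyScore(nn.Module):  # stand-in: higher is better
    def __init__(self):
        super().__init__()
        self.higher_better = True

LOSS_MAP = {"toy_mse": ToyMSE, "toy_score": ToyScore}
direction = [getattr(LOSS_MAP[name](), "higher_better") for name in ("toy_mse", "toy_score")]
print(len(set(direction)) > 1)  # True -> Avergeloss would raise ValueError for this mix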

View File

@@ -1,3 +1,7 @@
+try:
+    from functools import cached_property
+except ImportError:
+    from backports.cached_property import cached_property
 from importlib import import_module
 from huggingface_hub import cached_download, hf_hub_url
 import logging
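This shim keeps the module importable on older interpreters: functools.cached_property only exists from Python 3.8 onward, so earlier versions fall back to the backports.cached-property package. Usage is identical either way; a tiny demonstration:

try:
    from functools import cached_property  # Python >= 3.8
except ImportError:
    from backports.cached_property import cached_property  # pip install backports.cached-property

class Example:
    @cached_property
    def expensive(self):
        print("computed once")  # printed only on first access
        return 42

e = Example()
e.expensive  # computes and caches
e.expensive  # served from the cache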
@@ -42,6 +46,33 @@ class Model(pl.LightningModule):
         if self.logger:
             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
+        self.loss = loss
+        self.metric = metric
+
+    @property
+    def loss(self):
+        return self._loss
+
+    @loss.setter
+    def loss(self,loss):
+        if isinstance(loss,str):
+            loss = [loss]
+        self._loss = Avergeloss(loss)
+
+    @property
+    def metric(self):
+        return self._metric
+
+    @metric.setter
+    def metric(self,metric):
+        if isinstance(metric,str):
+            metric = [metric]
+        self._metric = Avergeloss(metric)

     @property
     def dataset(self):
@@ -55,15 +86,6 @@
         if stage == "fit":
             self.dataset.setup(stage)
             self.dataset.model = self
-            self.loss = self.setup_loss(self.hparams.loss)
-            self.metric = self.setup_loss(self.hparams.metric)
-
-    def setup_loss(self,loss):
-        if isinstance(loss,str):
-            losses = [loss]
-        return Avergeloss(losses)

     def train_dataloader(self):
         return self.dataset.train_dataloader()
@@ -226,6 +248,11 @@
         else:
             return waveform

+    @property
+    def valid_monitor(self):
+        return "max" if self.loss.higher_better else "min"

View File

@@ -1,2 +1,3 @@
from enhancer.utils.utils import check_files
from enhancer.utils.io import Audio
from enhancer.utils.config import Files

View File

@@ -2,7 +2,6 @@ from dataclasses import dataclass
 @dataclass
 class Files:
-    root_dir : str
     train_clean : str
     train_noisy : str
     test_clean : str
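With root_dir dropped from Files, the dataclass now carries only the four split folder names and the base path travels separately, as EnhancerDataset's root_dir argument. A quick hypothetical check of that layout (root and folder names invented):

import os
from enhancer.utils.config import Files

files = Files(
    train_clean="clean_train", train_noisy="noisy_train",
    test_clean="clean_test", test_noisy="noisy_test",
)
root_dir = "/data/vctk"  # supplied to the dataset, no longer stored on Files
for folder in (files.train_clean, files.train_noisy, files.test_clean, files.test_noisy):
    print(os.path.join(root_dir, folder))  # where each split is expected on disk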