Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
commit 10ec1a76c8

cli/train.py
@@ -4,7 +4,7 @@ from hydra.utils import instantiate
 from omegaconf import DictConfig
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 
 os.environ["HYDRA_FULL_ERROR"] = "1"
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
@@ -20,14 +20,15 @@ def main(config: DictConfig):
     model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
                         loss=parameters.get("loss"), metric = parameters.get("metric"))
 
+    direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor="valid_loss",verbose=False,
-        mode="min",every_n_epochs=1
+        dirpath="",filename="model",monitor="val_loss",verbose=False,
+        mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
     early_stopping = EarlyStopping(
-        monitor="valid_loss",
-        mode="min",
+        monitor="val_loss",
+        mode=direction,
         min_delta=0.0,
         patience=100,
         strict=True,
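Note: the callbacks no longer hard-code mode="min". The monitored quantity is renamed to "val_loss", and the direction now comes from model.valid_monitor (added later in this commit), so a higher-better loss is checkpointed with mode="max". A hypothetical continuation of main() showing how these callbacks are consumed (not part of this commit; max_epochs is a made-up value):

    import pytorch_lightning as pl

    # callbacks holds the checkpoint and early-stopping objects built above
    trainer = pl.Trainer(max_epochs=100, callbacks=callbacks)
    trainer.fit(model)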
@@ -91,7 +91,31 @@ class TaskDataset(pl.LightningDataModule):
         return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
 
 class EnhancerDataset(TaskDataset):
-    """Dataset object for creating clean-noisy speech enhancement datasets"""
+    """
+    Dataset object for creating clean-noisy speech enhancement datasets
+
+    parameters:
+    name : str
+        name of the dataset
+    root_dir : str
+        root directory of the dataset containing clean/noisy folders
+    files : Files
+        dataclass containing train_clean, train_noisy, test_clean, test_noisy
+        folder names (refer to cli/train_config/dataset)
+    duration : float
+        expected audio duration of a single audio sample for training
+    sampling_rate : int
+        desired sampling rate
+    batch_size : int
+        batch size of each batch
+    num_workers : int
+        number of workers to be used while training
+    matching_function : str
+        matching function - one of (one_to_one, one_to_many). Defaults to None.
+        use one_to_one mapping for datasets with one noisy file for each clean file
+        use one_to_many mapping for multiple noisy files for each clean file
+
+    """
 
     def __init__(
         self,
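For reference, a hypothetical instantiation matching the new docstring (the module path, directory names, and parameter values below are made-up examples; Files is the dataclass from enhancer/utils/config shown in the last hunk of this commit):

    from enhancer.data.dataset import EnhancerDataset  # import path assumed
    from enhancer.utils.config import Files

    files = Files(train_clean="clean_train", train_noisy="noisy_train",
                  test_clean="clean_test", test_noisy="noisy_test")
    dataset = EnhancerDataset(name="vctk", root_dir="/data/vctk", files=files,
                              duration=1.0, sampling_rate=16000,
                              batch_size=32, num_workers=4,
                              matching_function="one_to_one")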
@@ -4,10 +4,15 @@ from re import S
 import numpy as np
 from scipy.io import wavfile
 
+MATCHING_FNS = ("one_to_one","one_to_many")
+
 class ProcessorFunctions:
 
     @staticmethod
-    def match_vtck(clean_path,noisy_path):
+    def one_to_one(clean_path,noisy_path):
+        """
+        One clean audio can have only one noisy audio file
+        """
 
         matching_wavfiles = list()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -27,7 +32,10 @@ class ProcessorFunctions:
         return matching_wavfiles
 
     @staticmethod
-    def match_dns2020(clean_path,noisy_path):
+    def one_to_many(clean_path,noisy_path):
+        """
+        One clean audio can have multiple noisy audio files
+        """
 
         matching_wavfiles = dict()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -67,12 +75,18 @@ class Fileprocessor:
         matching_function=None
     ):
 
-        if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
-        elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
+        if matching_function is None:
+            if name.lower() == "vctk":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
+            elif name.lower() == "dns-2020":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
         else:
-            return cls(clean_dir,noisy_dir, matching_function)
+            if matching_function not in MATCHING_FNS:
+                raise ValueError(f"Invalid matching function! Available options are {MATCHING_FNS}")
+            else:
+                return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
 
 
     def prepare_matching_dict(self):
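A hypothetical call-site sketch for the new dispatch logic (the factory classmethod's name is not visible in this hunk, so from_name below is a stand-in, and the paths are made-up):

    # explicit matching function, validated against MATCHING_FNS
    fp = Fileprocessor.from_name("vctk", "/data/clean_train", "/data/noisy_train",
                                 matching_function="one_to_one")
    # or rely on the per-dataset default ("vctk" -> one_to_one, "dns-2020" -> one_to_many)
    fp = Fileprocessor.from_name("dns-2020", "/data/clean", "/data/noisy")
    matching = fp.prepare_matching_dict()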
@@ -1,5 +1,3 @@
-from modulefinder import Module
-from turtle import forward
 import torch
 import torch.nn as nn
 
@@ -10,6 +8,7 @@ class mean_squared_error(nn.Module):
         super().__init__()
 
         self.loss_fun = nn.MSELoss(reduction=reduction)
+        self.higher_better = False
 
     def forward(self,prediction:torch.Tensor, target: torch.Tensor):
@@ -25,6 +24,7 @@ class mean_absolute_error(nn.Module):
         super().__init__()
 
         self.loss_fun = nn.L1Loss(reduction=reduction)
+        self.higher_better = False
 
     def forward(self, prediction:torch.Tensor, target: torch.Tensor):
@@ -45,6 +45,7 @@ class Si_SDR(nn.Module):
             self.reduction = reduction
         else:
             raise TypeError("Invalid reduction, valid options are sum, mean, None")
+        self.higher_better = False
 
     def forward(self,prediction:torch.Tensor, target:torch.Tensor):
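Each loss class now exposes a higher_better flag so downstream code can infer the optimization direction without special-casing loss names, e.g. (assuming the default reduction argument):

    loss = mean_squared_error()
    loss.higher_better  # -> False: lower MSE is better, so callbacks monitor with mode="min"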
@@ -76,6 +77,12 @@ class Avergeloss(nn.Module):
         super().__init__()
 
         self.valid_losses = nn.ModuleList()
+
+        direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses]
+        if len(set(direction)) > 1:
+            raise ValueError("all cost functions should be of same nature, maximize or minimize!")
+
+        self.higher_better = direction[0]
         for loss in losses:
             loss = self.validate_loss(loss)
             self.valid_losses.append(loss())
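The direction list guards Avergeloss against averaging a minimized loss with a maximized one. A small behavioural sketch (assuming LOSS_MAP maps names such as "mse" and "mae" to the classes above, and "pesq" stands in for a hypothetical higher-better metric; the real keys are not shown in this hunk):

    avg = Avergeloss(["mse", "mae"])   # ok: both have higher_better == False
    avg.higher_better                  # -> False

    Avergeloss(["mse", "pesq"])        # raises ValueError: all cost functions should be of same nature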
@@ -1,3 +1,7 @@
+try:
+    from functools import cached_property
+except ImportError:
+    from backports.cached_property import cached_property
 from importlib import import_module
 from huggingface_hub import cached_download, hf_hub_url
 import logging
@@ -42,6 +46,33 @@ class Model(pl.LightningModule):
         if self.logger:
             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
 
+        self.loss = loss
+        self.metric = metric
+
+    @property
+    def loss(self):
+        return self._loss
+
+    @loss.setter
+    def loss(self,loss):
+
+        if isinstance(loss,str):
+            loss = [loss]
+
+        self._loss = Avergeloss(loss)
+
+    @property
+    def metric(self):
+        return self._metric
+
+    @metric.setter
+    def metric(self,metric):
+
+        if isinstance(metric,str):
+            metric = [metric]
+
+        self._metric = Avergeloss(metric)
+
+
     @property
     def dataset(self):
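With loss and metric exposed as properties, assigning a plain string transparently wraps it into an Avergeloss, replacing the setup_loss() helper removed in the next hunk. A hypothetical usage sketch ("mse" and "mae" assume keys of LOSS_MAP):

    model.loss = "mse"      # setter coerces the string to Avergeloss(["mse"])
    model.metric = ["mae"]  # lists pass through unchanged
    model.valid_monitor     # -> "min", since mean_squared_error.higher_better is False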
@@ -55,15 +86,6 @@ class Model(pl.LightningModule):
         if stage == "fit":
             self.dataset.setup(stage)
             self.dataset.model = self
-            self.loss = self.setup_loss(self.hparams.loss)
-            self.metric = self.setup_loss(self.hparams.metric)
-
-    def setup_loss(self,loss):
-
-        if isinstance(loss,str):
-            losses = [loss]
-
-        return Avergeloss(losses)
 
     def train_dataloader(self):
         return self.dataset.train_dataloader()
@@ -226,6 +248,11 @@ class Model(pl.LightningModule):
         else:
             return waveform
 
+    @property
+    def valid_monitor(self):
+
+        return "max" if self.loss.higher_better else "min"
+
 
@@ -1,2 +1,3 @@
 from enhancer.utils.utils import check_files
 from enhancer.utils.io import Audio
+from enhancer.utils.config import Files
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 
 @dataclass
 class Files:
-    root_dir : str
     train_clean : str
     train_noisy : str
    test_clean : str