Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
commit 10ec1a76c8

cli/train.py
@@ -4,7 +4,7 @@ from hydra.utils import instantiate
 from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 
+os.environ["HYDRA_FULL_ERROR"] = "1"
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
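Setting HYDRA_FULL_ERROR=1 makes Hydra print the full stack trace when config instantiation fails, instead of its abbreviated one-line error summary.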
@@ -20,14 +20,15 @@ def main(config: DictConfig):
     model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
                         loss=parameters.get("loss"), metric = parameters.get("metric"))
 
+    direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor="valid_loss",verbose=False,
-        mode="min",every_n_epochs=1
+        dirpath="",filename="model",monitor="val_loss",verbose=False,
+        mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
    early_stopping = EarlyStopping(
-        monitor="valid_loss",
-        mode="min",
+        monitor="val_loss",
+        mode=direction,
         min_delta=0.0,
         patience=100,
         strict=True,
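The callbacks now take their monitor direction from the model instead of hardcoding "min", so a maximized objective is checkpointed correctly. A minimal sketch of that contract, assuming the higher_better flag and valid_monitor property introduced later in this commit:

# Sketch: how the loss direction reaches ModelCheckpoint/EarlyStopping.
class DummyLoss:
    higher_better = False          # e.g. MSE: lower val_loss is better

class DummyModel:
    loss = DummyLoss()

    @property
    def valid_monitor(self):
        return "max" if self.loss.higher_better else "min"

model = DummyModel()
assert model.valid_monitor == "min"  # checkpoint keeps the lowest val_loss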
@@ -91,7 +91,31 @@ class TaskDataset(pl.LightningDataModule):
         return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
 
 class EnhancerDataset(TaskDataset):
-    """Dataset object for creating clean-noisy speech enhancement datasets"""
+    """
+    Dataset object for creating clean-noisy speech enhancement datasets
+
+    parameters:
+
+    name : str
+        name of the dataset
+    root_dir : str
+        root directory of the dataset containing clean/noisy folders
+    files : Files
+        dataclass containing train_clean, train_noisy, test_clean, test_noisy
+        folder names (refer cli/train_config/dataset)
+    duration : float
+        expected audio duration of a single audio sample for training
+    sampling_rate : int
+        desired sampling rate
+    batch_size : int
+        batch size of each batch
+    num_workers : int
+        num workers to be used while training
+    matching_function : str
+        matching functions - (one_to_one, one_to_many). Default set to None.
+        use one_to_one mapping for datasets with one noisy file for each clean file
+        use one_to_many mapping for multiple noisy files for each clean file
+
+    """
 
     def __init__(
         self,
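For orientation, a hypothetical instantiation matching the documented parameters; the paths and folder names below are invented, and the EnhancerDataset module path is an assumption:

from enhancer.data import EnhancerDataset  # assumed module path
from enhancer.utils import Files           # export added later in this commit

files = Files(
    train_clean="clean_trainset_wav",  # invented folder names
    train_noisy="noisy_trainset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(
    name="vctk",
    root_dir="/data/vctk",             # invented path
    files=files,
    duration=1.0,
    sampling_rate=16000,
    batch_size=32,
    num_workers=4,
    matching_function="one_to_one",    # one noisy file per clean file
)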
@@ -4,10 +4,15 @@ from re import S
 import numpy as np
 from scipy.io import wavfile
 
+MATCHING_FNS = ("one_to_one","one_to_many")
 
 class ProcessorFunctions:
 
     @staticmethod
-    def match_vtck(clean_path,noisy_path):
+    def one_to_one(clean_path,noisy_path):
+        """
+        One clean audio file can have only one noisy audio file
+        """
 
         matching_wavfiles = list()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -27,7 +32,10 @@ class ProcessorFunctions:
         return matching_wavfiles
 
     @staticmethod
-    def match_dns2020(clean_path,noisy_path):
+    def one_to_many(clean_path,noisy_path):
+        """
+        One clean audio file can have multiple noisy audio files
+        """
 
         matching_wavfiles = dict()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -67,12 +75,18 @@ class Fileprocessor:
         matching_function=None
     ):
 
-        if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
-        elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
-        else:
-            return cls(clean_dir,noisy_dir, matching_function)
+        if matching_function is None:
+            if name.lower() == "vctk":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
+            elif name.lower() == "dns-2020":
+                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
+        else:
+            if matching_function not in MATCHING_FNS:
+                raise ValueError(f"Invalid matching function! Available options are {MATCHING_FNS}")
+            else:
+                return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
 
 
     def prepare_matching_dict(self):
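The matching functions are renamed after the mapping they implement (one_to_one, one_to_many), which lets user-supplied strings be resolved with getattr; note that an unrecognized dataset name combined with matching_function=None now falls through without returning a processor. A rough usage sketch, where the classmethod name and the paths are placeholders, not confirmed by this diff:

# Known dataset: the matching function is inferred from the name.
fp = Fileprocessor.from_name("vctk", "/data/clean", "/data/noisy")

# Unknown dataset: pass one of MATCHING_FNS explicitly.
fp = Fileprocessor.from_name("my-dataset", "/data/clean", "/data/noisy",
                             matching_function="one_to_many")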
@@ -1,5 +1,3 @@
-from modulefinder import Module
-from turtle import forward
 import torch
 import torch.nn as nn
 
@@ -10,6 +8,7 @@ class mean_squared_error(nn.Module):
         super().__init__()
 
         self.loss_fun = nn.MSELoss(reduction=reduction)
+        self.higher_better = False
 
     def forward(self,prediction:torch.Tensor, target: torch.Tensor):
 
@@ -25,6 +24,7 @@ class mean_absolute_error(nn.Module):
         super().__init__()
 
         self.loss_fun = nn.L1Loss(reduction=reduction)
+        self.higher_better = False
 
     def forward(self, prediction:torch.Tensor, target: torch.Tensor):
 
@@ -45,6 +45,7 @@ class Si_SDR(nn.Module):
             self.reduction = reduction
         else:
             raise TypeError("Invalid reduction, valid options are sum, mean, None")
+        self.higher_better = False
 
     def forward(self,prediction:torch.Tensor, target:torch.Tensor):
 
@@ -76,6 +77,12 @@ class Avergeloss(nn.Module):
         super().__init__()
 
         self.valid_losses = nn.ModuleList()
 
+        direction = [getattr(LOSS_MAP[loss](),"higher_better") for loss in losses]
+        if len(set(direction)) > 1:
+            raise ValueError("all cost functions should be of same nature, maximize or minimize!")
+
+        self.higher_better = direction[0]
         for loss in losses:
             loss = self.validate_loss(loss)
             self.valid_losses.append(loss())
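Every loss now declares whether higher values are better, and Avergeloss rejects a mixture of maximized and minimized objectives up front. A toy illustration of that invariant; the LOSS_MAP entries here are stand-ins, not the project's real registry:

# Toy registry with one minimized and one (hypothetical) maximized objective.
class _MSE:
    def __init__(self):
        self.higher_better = False   # lower loss is better

class _PESQ:
    def __init__(self):
        self.higher_better = True    # hypothetical maximized metric

LOSS_MAP = {"mse": _MSE, "pesq": _PESQ}

def check_directions(losses):
    direction = [LOSS_MAP[name]().higher_better for name in losses]
    if len(set(direction)) > 1:
        raise ValueError("all cost functions should be of same nature, maximize or minimize!")
    return direction[0]

check_directions(["mse"])            # fine, returns False
# check_directions(["mse", "pesq"]) # raises ValueError: mixed directions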
@@ -1,3 +1,7 @@
+try:
+    from functools import cached_property
+except ImportError:
+    from backports.cached_property import cached_property
 from importlib import import_module
 from huggingface_hub import cached_download, hf_hub_url
 import logging
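functools.cached_property only exists on Python 3.8 and newer, so the guarded import keeps older interpreters working through the backports.cached-property distribution.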
@@ -42,7 +46,34 @@ class Model(pl.LightningModule):
         if self.logger:
             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
 
+        self.loss = loss
+        self.metric = metric
+
+    @property
+    def loss(self):
+        return self._loss
+
+    @loss.setter
+    def loss(self,loss):
+
+        if isinstance(loss,str):
+            loss = [loss]
+
+        self._loss = Avergeloss(loss)
+
+    @property
+    def metric(self):
+        return self._metric
+
+    @metric.setter
+    def metric(self,metric):
+
+        if isinstance(metric,str):
+            metric = [metric]
+
+        self._metric = Avergeloss(metric)
+
     @property
     def dataset(self):
         return self._dataset
@@ -55,16 +86,7 @@ class Model(pl.LightningModule):
         if stage == "fit":
             self.dataset.setup(stage)
             self.dataset.model = self
-            self.loss = self.setup_loss(self.hparams.loss)
-            self.metric = self.setup_loss(self.hparams.metric)
-
-    def setup_loss(self,loss):
-
-        if isinstance(loss,str):
-            losses = [loss]
-
-        return Avergeloss(losses)
-
 
     def train_dataloader(self):
         return self.dataset.train_dataloader()
@@ -224,7 +246,12 @@ class Model(pl.LightningModule):
             Inference.write_output(waveform,audio,model_sampling_rate)
 
         else:
             return waveform
 
+    @property
+    def valid_monitor(self):
+
+        return "max" if self.loss.higher_better else "min"
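Routing loss and metric through property setters means a single name or a list of names both normalize to an Avergeloss, the objective can be reassigned after construction, and valid_monitor derives the checkpoint direction from whatever is currently set. A condensed sketch of the pattern, with _AverageLoss standing in for the project's Avergeloss:

# Sketch of the setter pattern; _AverageLoss is a stand-in class.
class _AverageLoss:
    def __init__(self, losses):
        self.names = losses
        self.higher_better = False

class SketchModel:
    def __init__(self, loss="mse"):
        self.loss = loss               # runs the setter below

    @property
    def loss(self):
        return self._loss

    @loss.setter
    def loss(self, loss):
        if isinstance(loss, str):      # normalize "mse" -> ["mse"]
            loss = [loss]
        self._loss = _AverageLoss(loss)

m = SketchModel()
m.loss = ["mse", "mae"]                # reassignable after construction
print(m.loss.names)                    # ['mse', 'mae']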
@@ -1,2 +1,3 @@
 from enhancer.utils.utils import check_files
 from enhancer.utils.io import Audio
+from enhancer.utils.config import Files
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 
 @dataclass
 class Files:
-    root_dir : str
     train_clean : str
     train_noisy : str
     test_clean : str
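With root_dir dropped from Files, the dataset root is passed to EnhancerDataset directly (see its docstring above), leaving Files as a plain container of the four folder names.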