From 37fe86063d52b582e162d5cff92edbdc282a1991 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 10 Oct 2022 12:38:51 +0530
Subject: [PATCH 1/5] add model testing

---
 enhancer/cli/train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index a32c41f..7b245d8 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -75,6 +75,7 @@ def main(config: DictConfig):
 
     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)
+    trainer.test(model)
 
     logger.experiment.log_artifact(
         logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"

From 3e654d10a7dab7ebcfe823fd5ddf64fda1715993 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 10 Oct 2022 12:45:23 +0530
Subject: [PATCH 2/5] add test dataloader

---
 enhancer/data/dataset.py       | 77 ++++++++++++++++++++++++++--------
 enhancer/data/fileprocessor.py |  4 +-
 2 files changed, 62 insertions(+), 19 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 95c73a1..d2b7526 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -5,6 +5,7 @@ from typing import Optional
 
 import pytorch_lightning as pl
 import torch.nn.functional as F
+from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
 from enhancer.data.fileprocessor import Fileprocessor
@@ -36,12 +37,24 @@ class ValidDataset(Dataset):
         return self.dataset.val__len__()
 
 
+class TestDataset(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, idx):
+        return self.dataset.test__getitem__(idx)
+
+    def __len__(self):
+        return self.dataset.test__len__()
+
+
 class TaskDataset(pl.LightningDataModule):
     def __init__(
         self,
         name: str,
         root_dir: str,
         files: Files,
+        valid_size: float = 0.20,
         duration: float = 1.0,
         sampling_rate: int = 48000,
         matching_function=None,
@@ -60,8 +73,15 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
+        if valid_size > 0.0:
+            self.valid_size = valid_size
+        else:
+            raise ValueError("valid_size must be greater than 0")
 
     def setup(self, stage: Optional[str] = None):
+        """
+        Prepare train/validation/test data splits.
+        """
 
         if stage in ("fit", None):
 
@@ -70,25 +90,33 @@
             fp = Fileprocessor.from_name(
                 self.name, train_clean, train_noisy, self.matching_function
             )
-            self.train_data = fp.prepare_matching_dict()
-
-            val_clean = os.path.join(self.root_dir, self.files.test_clean)
-            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
-            fp = Fileprocessor.from_name(
-                self.name, val_clean, val_noisy, self.matching_function
+            train_data = fp.prepare_matching_dict()
+            self.train_data, self.val_data = train_test_split(
+                train_data, test_size=self.valid_size, shuffle=True, random_state=42
             )
-            val_data = fp.prepare_matching_dict()
 
-            for item in val_data:
-                clean, noisy, total_dur = item.values()
-                if total_dur < self.duration:
-                    continue
-                num_segments = round(total_dur / self.duration)
-                for index in range(num_segments):
-                    start_time = index * self.duration
-                    self._validation.append(
-                        ({"clean": clean, "noisy": noisy}, start_time)
-                    )
+            self._validation = self.prepare_mapstype(self.val_data)
+
+            test_clean = os.path.join(self.root_dir, self.files.test_clean)
+            test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, test_clean, test_noisy, self.matching_function
+            )
+            test_data = fp.prepare_matching_dict()
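+            # expand matched (clean, noisy) pairs into fixed-duration map-style samples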
+            self._test = self.prepare_mapstype(test_data)
+
+    def prepare_mapstype(self, data):
+        """
+        Expand (clean, noisy, duration) entries into map-style
+        ({"clean", "noisy"}, start_time) samples of fixed duration.
+        """
+        metadata = []
+        for item in data:
+            clean, noisy, total_dur = item.values()
+            if total_dur < self.duration:
+                continue
+            num_segments = round(total_dur / self.duration)
+            for index in range(num_segments):
+                start_time = index * self.duration
+                metadata.append(({"clean": clean, "noisy": noisy}, start_time))
+        return metadata
 
     def train_dataloader(self):
         return DataLoader(
@@ -104,6 +132,13 @@
             num_workers=self.num_workers,
         )
 
+    def test_dataloader(self):
+        return DataLoader(
+            TestDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
+
 
 class EnhancerDataset(TaskDataset):
     """
@@ -137,6 +172,7 @@ class EnhancerDataset(TaskDataset):
         name: str,
         root_dir: str,
         files: Files,
+        valid_size=0.2,
         duration=1.0,
         sampling_rate=48000,
         matching_function=None,
@@ -148,6 +184,7 @@
             name=name,
             root_dir=root_dir,
             files=files,
+            valid_size=valid_size,
             sampling_rate=sampling_rate,
             duration=duration,
             matching_function=matching_function,
@@ -183,6 +220,9 @@
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
 
+    def test__getitem__(self, idx):
+        return self.prepare_segment(*self._test[idx])
+
     def prepare_segment(self, file_dict: dict, start_time: float):
 
         clean_segment = self.audio(
@@ -218,3 +258,6 @@
 
     def val__len__(self):
         return len(self._validation)
+
+    def test__len__(self):
+        return len(self._test)
diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py
index 03afc73..e718f15 100644
--- a/enhancer/data/fileprocessor.py
+++ b/enhancer/data/fileprocessor.py
@@ -55,7 +55,7 @@ class ProcessorFunctions:
         One clean audio file may have multiple noisy audio files
         """
 
-        matching_wavfiles = dict()
+        matching_wavfiles = list()
         clean_filenames = [
             file.split("/")[-1]
             for file in glob.glob(os.path.join(clean_path, "*.wav"))
@@ -73,7 +73,7 @@ class ProcessorFunctions:
             if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                 sr_clean == sr_noisy
             ):
-                matching_wavfiles.update(
+                matching_wavfiles.append(
                     {
                         "clean": os.path.join(clean_path, clean_file),
                         "noisy": noisy_file,

From 5945ddccaaab990f2df2bd06cf42d75c7e3b56ca Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 10 Oct 2022 12:46:36 +0530
Subject: [PATCH 3/5] add pesq/stoi

---
 enhancer/loss.py | 49 ++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 45 insertions(+), 4 deletions(-)

diff --git a/enhancer/loss.py b/enhancer/loss.py
index db1d222..9ef90d2 100644
--- a/enhancer/loss.py
+++ b/enhancer/loss.py
@@ -1,5 +1,9 @@
+import logging
+
 import torch
 import torch.nn as nn
+from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
+from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
 
 
 class mean_squared_error(nn.Module):
@@ -12,6 +16,7 @@ class mean_squared_error(nn.Module):
 
         self.loss_fun = nn.MSELoss(reduction=reduction)
         self.higher_better = False
+        self.name = "mse"
 
     def forward(self, prediction: torch.Tensor, target: torch.Tensor):
 
@@ -34,6 +39,7 @@ class mean_absolute_error(nn.Module):
 
         self.loss_fun = nn.L1Loss(reduction=reduction)
         self.higher_better = False
+        self.name = "mae"
 
     def forward(self, prediction: torch.Tensor, target: torch.Tensor):
 
@@ -46,13 +52,12 @@ class mean_absolute_error(nn.Module):
         return self.loss_fun(prediction, target)
 
 
-class Si_SDR(nn.Module):
+class Si_SDR:
     """
     SI-SDR metric
     based on SDR – HALF-BAKED OR WELL DONE? (https://arxiv.org/pdf/1811.02508.pdf)
     """
 
     def __init__(self, reduction: str = "mean"):
-        super().__init__()
         if reduction in ["sum", "mean", None]:
             self.reduction = reduction
         else:
@@ -60,8 +65,9 @@
             "Invalid reduction, valid options are sum, mean, None"
         )
         self.higher_better = False
+        self.name = "Si-SDR"
 
-    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
 
         if prediction.size() != target.size() or target.ndim < 3:
             raise TypeError(
@@ -90,7 +96,40 @@
         return si_sdr
 
 
+class Stoi:
+    """
+    STOI (Short-Time Objective Intelligibility), a wrapper around the
+    pystoi package via torchmetrics. Note that input will be moved to
+    cpu to perform the metric calculation.
+
+    parameters:
+    sr : int
+        sampling rate
+    """
+
+    def __init__(self, sr: int):
+        self.sr = sr
+        self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
+        self.name = "stoi"
+
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
+
+        return self.stoi(prediction, target)
+
+
+class Pesq:
+    """
+    PESQ (Perceptual Evaluation of Speech Quality), a wrapper around the
+    pesq package via torchmetrics; fs must be 8000 ("nb") or 16000 ("nb"/"wb").
+    """
+
+    def __init__(self, sr: int, mode="nb"):
+
+        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
+        self.name = "pesq"
+
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
+        try:
+            return self.pesq(prediction, target)
+        except Exception as e:
+            logging.warning(f"{e} error occurred while calculating PESQ")
+            return 0.0
+
+
-class Avergeloss(nn.Module):
+class LossWrapper(nn.Module):
     """
     Combine multiple metrics of the same nature,
     for example ["mse", "mae"]
     """
 
 LOSS_MAP = {
     "mae": mean_absolute_error,
     "mse": mean_squared_error,
     "SI-SDR": Si_SDR,
+    "pesq": Pesq,
+    "stoi": Stoi,
 }

From 1aca956ed44606b9e609e16120144287ceec938e Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 10 Oct 2022 12:46:59 +0530
Subject: [PATCH 4/5] update requirements

---
 requirements.txt | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/requirements.txt b/requirements.txt
index 3762fd2..95f145d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,9 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
+pesq==0.0.4
 protobuf>=3.19.6
+pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3
 pytorch-lightning>=1.7.7
 scikit-learn>=1.1.2

From 2587d5b8675edf5239ebcf14265e80958e6764e5 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 10 Oct 2022 12:47:24 +0530
Subject: [PATCH 5/5] add test step

---
 enhancer/models/model.py | 48 ++++++++++++++++++++++++++++++----------
 1 file changed, 36 insertions(+), 12 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 7ff15e4..07564cd 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -13,7 +13,7 @@ from torch.optim import Adam
 
 from enhancer.data.dataset import EnhancerDataset
 from enhancer.inference import Inference
-from enhancer.loss import Avergeloss
+from enhancer.loss import LOSS_MAP, LossWrapper
 from enhancer.version import __version__
 
 CACHE_DIR = ""
@@ -76,7 +76,7 @@ class Model(pl.LightningModule):
         if isinstance(loss, str):
             loss = [loss]
 
-        self._loss = Avergeloss(loss)
+        self._loss = LossWrapper(loss)
 
     @property
     def metric(self):
@@ -84,11 +84,21 @@ class Model(pl.LightningModule):
 
     @metric.setter
     def metric(self, metric):
-
+        self._metric = []
         if isinstance(metric, str):
             metric = [metric]
 
-        self._metric = Avergeloss(metric)
+        for func in metric:
+            if func in LOSS_MAP.keys():
+                if func in ("pesq", "stoi"):
+                    self._metric.append(
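+                        # pesq/stoi wrappers need the sampling rate at construction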
+                        LOSS_MAP[func](self.hparams.sampling_rate)
+                    )
+                else:
+                    self._metric.append(LOSS_MAP[func]())
+
+            else:
+                raise ValueError(f"Invalid metric {func}")
 
     @property
     def dataset(self):
@@ -109,6 +119,9 @@ class Model(pl.LightningModule):
     def val_dataloader(self):
         return self.dataset.val_dataloader()
 
+    def test_dataloader(self):
+        return self.dataset.test_dataloader()
+
     def configure_optimizers(self):
         return Adam(self.parameters(), lr=self.hparams.lr)
 
@@ -140,9 +153,7 @@ class Model(pl.LightningModule):
         target = batch["clean"]
         prediction = self(mixed_waveform)
 
-        metric_val = self.metric(prediction, target)
         loss_val = self.loss(prediction, target)
-        self.log("val_metric", metric_val.item())
         self.log("val_loss", loss_val.item())
 
         if (
@@ -156,15 +167,28 @@ class Model(pl.LightningModule):
                 value=loss_val.item(),
                 step=self.global_step,
             )
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="val_metric",
-                value=metric_val.item(),
-                step=self.global_step,
-            )
 
         return {"loss": loss_val}
 
+    def test_step(self, batch, batch_idx):
+
+        metric_dict = {}
+        mixed_waveform = batch["noisy"]
+        target = batch["clean"]
+        prediction = self(mixed_waveform)
+
+        for metric in self.metric:
+            # every metric wrapper follows the (prediction, target) convention
+            value = metric(prediction, target)
+            metric_dict[metric.name] = float(value)
+
+        # MlflowClient exposes log_metric (not log_metrics), so log one by one
+        for name, value in metric_dict.items():
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key=name,
+                value=value,
+                step=self.global_step,
+            )
+
+        return metric_dict
+
     def on_save_checkpoint(self, checkpoint):
 
         checkpoint["enhancer"] = {
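
A quick standalone sanity check of the metric wrappers introduced in PATCH 3
(a minimal sketch, assuming the patched enhancer.loss module is importable).
It uses 16 kHz audio because the pesq backend only accepts 8 kHz ("nb") and
16 kHz ("nb"/"wb"), so the pesq metric cannot run at the 48 kHz default
sampling rate of EnhancerDataset; stoi has no such restriction.

    import torch

    from enhancer.loss import Pesq, Stoi

    sr = 16000
    pesq = Pesq(sr, mode="wb")
    stoi = Stoi(sr)

    # Random tensors stand in for real (batch, channels, samples) audio.
    target = torch.randn(2, 1, sr)
    prediction = target + 0.01 * torch.randn_like(target)

    # Both wrappers follow the (prediction, target) convention used in test_step.
    print("stoi:", stoi(prediction, target))  # near 1.0 for an almost-identical pair
    # pesq may log a warning and return 0.0 here: random noise contains no
    # detectable utterance, which is exactly the failure mode the wrapper's
    # try/except is meant to absorb.
    print("pesq:", pesq(prediction, target))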