Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-10 12:48:11 +05:30
commit 5d8f49d78e
6 changed files with 146 additions and 35 deletions

View File

@ -75,6 +75,7 @@ def main(config: DictConfig):
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"

View File

@ -5,6 +5,7 @@ from typing import Optional
import pytorch_lightning as pl
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, IterableDataset
from enhancer.data.fileprocessor import Fileprocessor
@ -36,12 +37,24 @@ class ValidDataset(Dataset):
return self.dataset.val__len__()
class TestDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset.test__getitem__(idx)
def __len__(self):
return self.dataset.test__len__()
class TaskDataset(pl.LightningDataModule):
def __init__(
self,
name: str,
root_dir: str,
files: Files,
valid_size: float = 0.20,
duration: float = 1.0,
sampling_rate: int = 48000,
matching_function=None,
@ -60,8 +73,15 @@ class TaskDataset(pl.LightningDataModule):
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers
if valid_size > 0.0:
self.valid_size = valid_size
else:
raise ValueError("valid_size must be greater than 0")
def setup(self, stage: Optional[str] = None):
"""
prepare train/validation/test data splits
"""
if stage in ("fit", None):
@ -70,25 +90,33 @@ class TaskDataset(pl.LightningDataModule):
fp = Fileprocessor.from_name(
self.name, train_clean, train_noisy, self.matching_function
)
self.train_data = fp.prepare_matching_dict()
val_clean = os.path.join(self.root_dir, self.files.test_clean)
val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(
self.name, val_clean, val_noisy, self.matching_function
train_data = fp.prepare_matching_dict()
self.train_data, self.val_data = train_test_split(
train_data, test_size=0.20, shuffle=True, random_state=42
)
val_data = fp.prepare_matching_dict()
for item in val_data:
self._validation = self.prepare_mapstype(self.val_data)
test_clean = os.path.join(self.root_dir, self.files.test_clean)
test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(
self.name, test_clean, test_noisy, self.matching_function
)
test_data = fp.prepare_matching_dict()
self._test = self.prepare_mapstype(test_data)
def prepare_mapstype(self, data):
metadata = []
for item in data:
clean, noisy, total_dur = item.values()
if total_dur < self.duration:
continue
num_segments = round(total_dur / self.duration)
for index in range(num_segments):
start_time = index * self.duration
self._validation.append(
({"clean": clean, "noisy": noisy}, start_time)
)
metadata.append(({"clean": clean, "noisy": noisy}, start_time))
return metadata
def train_dataloader(self):
return DataLoader(
@ -104,6 +132,13 @@ class TaskDataset(pl.LightningDataModule):
num_workers=self.num_workers,
)
def test_dataloader(self):
return DataLoader(
TestDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
class EnhancerDataset(TaskDataset):
"""
@ -137,6 +172,7 @@ class EnhancerDataset(TaskDataset):
name: str,
root_dir: str,
files: Files,
valid_size=0.2,
duration=1.0,
sampling_rate=48000,
matching_function=None,
@ -148,6 +184,7 @@ class EnhancerDataset(TaskDataset):
name=name,
root_dir=root_dir,
files=files,
valid_size=valid_size,
sampling_rate=sampling_rate,
duration=duration,
matching_function=matching_function,
@ -183,6 +220,9 @@ class EnhancerDataset(TaskDataset):
def val__getitem__(self, idx):
return self.prepare_segment(*self._validation[idx])
def test__getitem__(self, idx):
return self.prepare_segment(*self._test[idx])
def prepare_segment(self, file_dict: dict, start_time: float):
clean_segment = self.audio(
@ -218,3 +258,6 @@ class EnhancerDataset(TaskDataset):
def val__len__(self):
return len(self._validation)
def test__len__(self):
return len(self._test)

View File

@ -55,7 +55,7 @@ class ProcessorFunctions:
One clean audio have multiple noisy audio files
"""
matching_wavfiles = dict()
matching_wavfiles = list()
clean_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav"))
@ -73,7 +73,7 @@ class ProcessorFunctions:
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy
):
matching_wavfiles.update(
matching_wavfiles.append(
{
"clean": os.path.join(clean_path, clean_file),
"noisy": noisy_file,

View File

@ -1,5 +1,9 @@
import logging
import torch
import torch.nn as nn
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
class mean_squared_error(nn.Module):
@ -12,6 +16,7 @@ class mean_squared_error(nn.Module):
self.loss_fun = nn.MSELoss(reduction=reduction)
self.higher_better = False
self.name = "mse"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@ -34,6 +39,7 @@ class mean_absolute_error(nn.Module):
self.loss_fun = nn.L1Loss(reduction=reduction)
self.higher_better = False
self.name = "mae"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@ -46,13 +52,12 @@ class mean_absolute_error(nn.Module):
return self.loss_fun(prediction, target)
class Si_SDR(nn.Module):
class Si_SDR:
"""
SI-SDR metric based on SDR HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
"""
def __init__(self, reduction: str = "mean"):
super().__init__()
if reduction in ["sum", "mean", None]:
self.reduction = reduction
else:
@ -60,8 +65,9 @@ class Si_SDR(nn.Module):
"Invalid reduction, valid options are sum, mean, None"
)
self.higher_better = False
self.name = "Si-SDR"
def forward(self, prediction: torch.Tensor, target: torch.Tensor):
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
@ -90,7 +96,40 @@ class Si_SDR(nn.Module):
return si_sdr
class Avergeloss(nn.Module):
class Stoi:
"""
STOI (Short-Time Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
Note that input will be moved to cpu to perform the metric calculation.
parameters:
sr: int
sampling rate
"""
def __init__(self, sr: int):
self.sr = sr
self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
self.name = "stoi"
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
return self.stoi(prediction, target)
class Pesq:
def __init__(self, sr: int, mode="nb"):
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
self.name = "pesq"
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
try:
return self.pesq(prediction, target)
except Exception as e:
logging.warning(f"{e} error occured while calculating PESQ")
return 0.0
class LossWrapper(nn.Module):
"""
Combine multiple metics of same nature.
for example, ["mea","mae"]
@ -137,4 +176,6 @@ LOSS_MAP = {
"mae": mean_absolute_error,
"mse": mean_squared_error,
"SI-SDR": Si_SDR,
"pesq": Pesq,
"stoi": Stoi,
}

View File

@ -13,7 +13,7 @@ from torch.optim import Adam
from enhancer.data.dataset import EnhancerDataset
from enhancer.inference import Inference
from enhancer.loss import Avergeloss
from enhancer.loss import LOSS_MAP, LossWrapper
from enhancer.version import __version__
CACHE_DIR = ""
@ -76,7 +76,7 @@ class Model(pl.LightningModule):
if isinstance(loss, str):
loss = [loss]
self._loss = Avergeloss(loss)
self._loss = LossWrapper(loss)
@property
def metric(self):
@ -84,11 +84,21 @@ class Model(pl.LightningModule):
@metric.setter
def metric(self, metric):
self._metric = []
if isinstance(metric, str):
metric = [metric]
self._metric = Avergeloss(metric)
for func in metric:
if func in LOSS_MAP.keys():
if func in ("pesq", "stoi"):
self._metric.append(
LOSS_MAP[func](self.hparams.sampling_rate)
)
else:
self._metric.append(LOSS_MAP[func]())
else:
raise ValueError(f"Invalid metrics {func}")
@property
def dataset(self):
@ -109,6 +119,9 @@ class Model(pl.LightningModule):
def val_dataloader(self):
return self.dataset.val_dataloader()
def test_dataloader(self):
return self.dataset.test_dataloader()
def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.lr)
@ -140,9 +153,7 @@ class Model(pl.LightningModule):
target = batch["clean"]
prediction = self(mixed_waveform)
metric_val = self.metric(prediction, target)
loss_val = self.loss(prediction, target)
self.log("val_metric", metric_val.item())
self.log("val_loss", loss_val.item())
if (
@ -156,14 +167,27 @@ class Model(pl.LightningModule):
value=loss_val.item(),
step=self.global_step,
)
self.logger.experiment.log_metric(
return {"loss": loss_val}
def test_step(self, batch, batch_idx):
metric_dict = {}
mixed_waveform = batch["noisy"]
target = batch["clean"]
prediction = self(mixed_waveform)
for metric in self.metric:
value = metric(target, prediction)
metric_dict[metric.name] = value
self.logger.experiment.log_metrics(
run_id=self.logger.run_id,
key="val_metric",
value=metric_val.item(),
metrics=metric_dict,
step=self.global_step,
)
return {"loss": loss_val}
return metric_dict
def on_save_checkpoint(self, checkpoint):

View File

@ -5,7 +5,9 @@ joblib>=1.2.0
librosa>=0.9.2
mlflow>=1.29.0
numpy>=1.23.3
pesq==0.0.4
protobuf>=3.19.6
pystoi==0.3.3
pytest-lazy-fixture>=0.6.3
pytorch-lightning>=1.7.7
scikit-learn>=1.1.2