Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
commit 5d8f49d78e
@@ -75,6 +75,7 @@ def main(config: DictConfig):
     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)
+    trainer.test(model)

     logger.experiment.log_artifact(
         logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
|
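Note: the added trainer.test(model) call is what exercises the test path wired up by the rest of this commit (the new test_dataloader and test_step below). A minimal sketch of the resulting flow, with illustrative names rather than the project's actual training script:

import pytorch_lightning as pl

def train_and_evaluate(model: pl.LightningModule, trainer: pl.Trainer):
    trainer.fit(model)              # existing behaviour: train + validate
    results = trainer.test(model)   # new: runs Model.test_step over test_dataloader()
    return results                  # list with one metrics dict per test dataloader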
@@ -5,6 +5,7 @@ from typing import Optional

 import pytorch_lightning as pl
 import torch.nn.functional as F
+from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader, Dataset, IterableDataset

 from enhancer.data.fileprocessor import Fileprocessor
@@ -36,12 +37,24 @@ class ValidDataset(Dataset):
         return self.dataset.val__len__()


+class TestDataset(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, idx):
+        return self.dataset.test__getitem__(idx)
+
+    def __len__(self):
+        return self.dataset.test__len__()
+
+
 class TaskDataset(pl.LightningDataModule):
     def __init__(
         self,
         name: str,
         root_dir: str,
         files: Files,
+        valid_size: float = 0.20,
         duration: float = 1.0,
         sampling_rate: int = 48000,
         matching_function=None,
@@ -60,8 +73,15 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
+        if valid_size > 0.0:
+            self.valid_size = valid_size
+        else:
+            raise ValueError("valid_size must be greater than 0")

     def setup(self, stage: Optional[str] = None):
+        """
+        prepare train/validation/test data splits
+        """

         if stage in ("fit", None):
@@ -70,25 +90,33 @@ class TaskDataset(pl.LightningDataModule):
             fp = Fileprocessor.from_name(
                 self.name, train_clean, train_noisy, self.matching_function
             )
-            self.train_data = fp.prepare_matching_dict()
-            val_clean = os.path.join(self.root_dir, self.files.test_clean)
-            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
-            fp = Fileprocessor.from_name(
-                self.name, val_clean, val_noisy, self.matching_function
-            )
-            val_data = fp.prepare_matching_dict()
-
-            for item in val_data:
+            train_data = fp.prepare_matching_dict()
+            self.train_data, self.val_data = train_test_split(
+                train_data, test_size=0.20, shuffle=True, random_state=42
+            )
+            self._validation = self.prepare_mapstype(self.val_data)
+
+            test_clean = os.path.join(self.root_dir, self.files.test_clean)
+            test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, test_clean, test_noisy, self.matching_function
+            )
+            test_data = fp.prepare_matching_dict()
+            self._test = self.prepare_mapstype(test_data)
+
+    def prepare_mapstype(self, data):
+
+        metadata = []
+        for item in data:
             clean, noisy, total_dur = item.values()
             if total_dur < self.duration:
                 continue
             num_segments = round(total_dur / self.duration)
             for index in range(num_segments):
                 start_time = index * self.duration
-                self._validation.append(
-                    ({"clean": clean, "noisy": noisy}, start_time)
-                )
+                metadata.append(({"clean": clean, "noisy": noisy}, start_time))
+        return metadata

     def train_dataloader(self):
         return DataLoader(
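This hunk replaces the old hand-built validation index with two steps: an 80/20 train_test_split of the matched training pairs (note the hard-coded test_size=0.20 and random_state=42, rather than the new valid_size argument), and a shared prepare_mapstype helper that expands each matched file pair into fixed-duration segment entries. A standalone sketch of what prepare_mapstype computes; the "duration" key name in the example items is an assumption, inferred from the clean, noisy, total_dur unpacking above:

def prepare_mapstype(data, duration=1.0):
    """Expand matched file pairs into (file_dict, start_time) segment entries."""
    metadata = []
    for item in data:
        clean, noisy, total_dur = item.values()
        if total_dur < duration:
            continue  # files shorter than one segment are skipped
        num_segments = round(total_dur / duration)
        for index in range(num_segments):
            metadata.append(({"clean": clean, "noisy": noisy}, index * duration))
    return metadata

# A 3.2 s file with duration=1.0 yields 3 segments starting at 0.0, 1.0 and 2.0.
segments = prepare_mapstype(
    [{"clean": "clean/a.wav", "noisy": "noisy/a.wav", "duration": 3.2}]
)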
@@ -104,6 +132,13 @@ class TaskDataset(pl.LightningDataModule):
             num_workers=self.num_workers,
         )

+    def test_dataloader(self):
+        return DataLoader(
+            TestDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
+

 class EnhancerDataset(TaskDataset):
     """
@@ -137,6 +172,7 @@ class EnhancerDataset(TaskDataset):
         name: str,
         root_dir: str,
         files: Files,
+        valid_size=0.2,
         duration=1.0,
         sampling_rate=48000,
         matching_function=None,
@@ -148,6 +184,7 @@ class EnhancerDataset(TaskDataset):
             name=name,
             root_dir=root_dir,
             files=files,
+            valid_size=valid_size,
             sampling_rate=sampling_rate,
             duration=duration,
             matching_function=matching_function,
@@ -183,6 +220,9 @@ class EnhancerDataset(TaskDataset):
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])

+    def test__getitem__(self, idx):
+        return self.prepare_segment(*self._test[idx])
+
     def prepare_segment(self, file_dict: dict, start_time: float):

         clean_segment = self.audio(
@@ -218,3 +258,6 @@ class EnhancerDataset(TaskDataset):

     def val__len__(self):
         return len(self._validation)
+
+    def test__len__(self):
+        return len(self._test)
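Taken together, the dataset changes add a third map-style view (TestDataset delegating to test__getitem__/test__len__) next to the existing train and validation views, a test_dataloader, and a valid_size constructor argument. A hedged usage sketch; the Files stand-in and all names and paths below are placeholders, not the project's actual values:

from collections import namedtuple

from enhancer.data.dataset import EnhancerDataset

# Placeholder stand-in for the project's Files container and corpus layout.
FilesStub = namedtuple(
    "FilesStub", ["train_clean", "train_noisy", "test_clean", "test_noisy"]
)
files = FilesStub("clean_trainset", "noisy_trainset", "clean_testset", "noisy_testset")

dataset = EnhancerDataset(
    name="vctk",               # placeholder dataset name
    root_dir="/path/to/data",  # placeholder root directory
    files=files,
    valid_size=0.2,            # new argument introduced in this commit
    duration=1.0,
    sampling_rate=48000,
)
dataset.setup(stage=None)                # prepares train/validation/test metadata
test_loader = dataset.test_dataloader()  # new: DataLoader over TestDataset(dataset)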
@@ -55,7 +55,7 @@ class ProcessorFunctions:
        One clean audio have multiple noisy audio files
        """

-        matching_wavfiles = dict()
+        matching_wavfiles = list()
        clean_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
@@ -73,7 +73,7 @@ class ProcessorFunctions:
                 if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                     sr_clean == sr_noisy
                 ):
-                    matching_wavfiles.update(
+                    matching_wavfiles.append(
                         {
                             "clean": os.path.join(clean_path, clean_file),
                             "noisy": noisy_file,
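The switch from dict()/update() to list()/append() matters because, as the docstring above says, one clean file can match several noisy files: repeated dict.update calls with the same keys keep only the last pair, whereas a list keeps every match. A small illustration (file names invented):

pairs = [
    {"clean": "clean/p226_001.wav", "noisy": "noisy/p226_001_babble.wav"},
    {"clean": "clean/p226_001.wav", "noisy": "noisy/p226_001_street.wav"},
]

old_style = dict()
for pair in pairs:
    old_style.update(pair)   # later pairs overwrite the same keys
# old_style == {"clean": "clean/p226_001.wav", "noisy": "noisy/p226_001_street.wav"}

new_style = list()
for pair in pairs:
    new_style.append(pair)   # both clean/noisy matches are preserved
# len(new_style) == 2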
@@ -1,5 +1,9 @@
+import logging
+
 import torch
 import torch.nn as nn
+from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
+from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility


 class mean_squared_error(nn.Module):
@@ -12,6 +16,7 @@ class mean_squared_error(nn.Module):

         self.loss_fun = nn.MSELoss(reduction=reduction)
         self.higher_better = False
+        self.name = "mse"

     def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@@ -34,6 +39,7 @@ class mean_absolute_error(nn.Module):

         self.loss_fun = nn.L1Loss(reduction=reduction)
         self.higher_better = False
+        self.name = "mae"

     def forward(self, prediction: torch.Tensor, target: torch.Tensor):
@@ -46,13 +52,12 @@ class mean_absolute_error(nn.Module):
         return self.loss_fun(prediction, target)


-class Si_SDR(nn.Module):
+class Si_SDR:
     """
     SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE?(https://arxiv.org/pdf/1811.02508.pdf)
     """

     def __init__(self, reduction: str = "mean"):
-        super().__init__()
         if reduction in ["sum", "mean", None]:
             self.reduction = reduction
         else:
@@ -60,8 +65,9 @@ class Si_SDR(nn.Module):
                 "Invalid reduction, valid options are sum, mean, None"
             )
         self.higher_better = False
+        self.name = "Si-SDR"

-    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

         if prediction.size() != target.size() or target.ndim < 3:
             raise TypeError(
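Si_SDR also stops subclassing nn.Module (hence the dropped super().__init__()) and renames forward to __call__, matching the plain-callable shape of the Pesq and Stoi wrappers added below, so all metrics can live in an ordinary Python list on the model. A sketch of that shared wrapper shape, illustrative rather than the module's exact code:

import torch

class MetricWrapper:
    """Illustrative shape shared by Si_SDR, Pesq and Stoi after this commit."""

    def __init__(self):
        self.name = "example"

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
        # placeholder computation; the real wrappers compute SI-SDR, PESQ or STOI
        return torch.mean(torch.abs(prediction - target))

metric = MetricWrapper()
score = metric(torch.zeros(2, 1, 16000), torch.zeros(2, 1, 16000))
print(metric.name, score)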
@@ -90,7 +96,40 @@ class Si_SDR(nn.Module):
         return si_sdr


-class Avergeloss(nn.Module):
+class Stoi:
+    """
+    STOI (Short-Time Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
+    Note that input will be moved to cpu to perform the metric calculation.
+    parameters:
+        sr: int
+            sampling rate
+    """
+
+    def __init__(self, sr: int):
+        self.sr = sr
+        self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
+        self.name = "stoi"
+
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
+
+        return self.stoi(prediction, target)
+
+
+class Pesq:
+    def __init__(self, sr: int, mode="nb"):
+
+        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
+        self.name = "pesq"
+
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
+        try:
+            return self.pesq(prediction, target)
+        except Exception as e:
+            logging.warning(f"{e} error occured while calculating PESQ")
+            return 0.0
+
+
+class LossWrapper(nn.Module):
     """
     Combine multiple metics of same nature.
     for example, ["mea","mae"]
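Stoi and Pesq are thin callable wrappers over torchmetrics' ShortTimeObjectiveIntelligibility and PerceptualEvaluationSpeechQuality (backed by the pystoi and pesq packages added to the requirements below). PESQ is wrapped in try/except because the underlying pesq library can fail on inputs it cannot score, in which case the wrapper logs a warning and returns 0.0. A usage sketch of the underlying torchmetrics objects; the 16 kHz rate and tensor shapes are illustrative, and PESQ itself only supports 8 kHz or 16 kHz audio ("wb" mode requires 16 kHz):

import torch
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility

sr = 16000
stoi = ShortTimeObjectiveIntelligibility(fs=sr)
pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode="wb")

target = torch.randn(sr)                      # 1 s of reference audio
prediction = target + 0.05 * torch.randn(sr)  # lightly degraded estimate

print("stoi:", stoi(prediction, target))  # intelligibility, roughly 0..1
print("pesq:", pesq(prediction, target))  # MOS-LQO, roughly -0.5..4.5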
@@ -137,4 +176,6 @@ LOSS_MAP = {
     "mae": mean_absolute_error,
     "mse": mean_squared_error,
     "SI-SDR": Si_SDR,
+    "pesq": Pesq,
+    "stoi": Stoi,
 }
@@ -13,7 +13,7 @@ from torch.optim import Adam

 from enhancer.data.dataset import EnhancerDataset
 from enhancer.inference import Inference
-from enhancer.loss import Avergeloss
+from enhancer.loss import LOSS_MAP, LossWrapper
 from enhancer.version import __version__

 CACHE_DIR = ""
@@ -76,7 +76,7 @@ class Model(pl.LightningModule):
         if isinstance(loss, str):
             loss = [loss]

-        self._loss = Avergeloss(loss)
+        self._loss = LossWrapper(loss)

     @property
     def metric(self):
@@ -84,11 +84,21 @@ class Model(pl.LightningModule):

     @metric.setter
     def metric(self, metric):
+        self._metric = []
         if isinstance(metric, str):
             metric = [metric]

-        self._metric = Avergeloss(metric)
+        for func in metric:
+            if func in LOSS_MAP.keys():
+                if func in ("pesq", "stoi"):
+                    self._metric.append(
+                        LOSS_MAP[func](self.hparams.sampling_rate)
+                    )
+                else:
+                    self._metric.append(LOSS_MAP[func]())
+
+            else:
+                raise ValueError(f"Invalid metrics {func}")

     @property
     def dataset(self):
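The metric setter now resolves each requested name through LOSS_MAP and keeps the instances in a plain list: the perceptual metrics ("pesq", "stoi") are constructed with the model's sampling rate, everything else with no arguments, and unknown names raise a ValueError. A condensed sketch of that resolution logic, mirroring the setter above with an illustrative helper name:

def build_metrics(names, sampling_rate, loss_map):
    """Resolve metric names to wrapper instances, as the new metric setter does."""
    if isinstance(names, str):
        names = [names]
    metrics = []
    for func in names:
        if func not in loss_map:
            raise ValueError(f"Invalid metrics {func}")
        if func in ("pesq", "stoi"):
            metrics.append(loss_map[func](sampling_rate))  # sampling-rate-dependent
        else:
            metrics.append(loss_map[func]())               # e.g. "mae", "mse", "SI-SDR"
    return metrics

# e.g. build_metrics(["mae", "stoi"], 16000, LOSS_MAP)
#      -> [mean_absolute_error(), Stoi(16000)]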
@@ -109,6 +119,9 @@ class Model(pl.LightningModule):
     def val_dataloader(self):
         return self.dataset.val_dataloader()

+    def test_dataloader(self):
+        return self.dataset.test_dataloader()
+
     def configure_optimizers(self):
         return Adam(self.parameters(), lr=self.hparams.lr)
@@ -140,9 +153,7 @@ class Model(pl.LightningModule):
         target = batch["clean"]
         prediction = self(mixed_waveform)

-        metric_val = self.metric(prediction, target)
         loss_val = self.loss(prediction, target)
-        self.log("val_metric", metric_val.item())
         self.log("val_loss", loss_val.item())

         if (
@@ -156,14 +167,27 @@ class Model(pl.LightningModule):
                 value=loss_val.item(),
                 step=self.global_step,
             )
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="val_metric",
-                value=metric_val.item(),
-                step=self.global_step,
-            )

         return {"loss": loss_val}

+    def test_step(self, batch, batch_idx):
+
+        metric_dict = {}
+        mixed_waveform = batch["noisy"]
+        target = batch["clean"]
+        prediction = self(mixed_waveform)
+
+        for metric in self.metric:
+            value = metric(target, prediction)
+            metric_dict[metric.name] = value
+
+        self.logger.experiment.log_metrics(
+            run_id=self.logger.run_id,
+            metrics=metric_dict,
+            step=self.global_step,
+        )
+
+        return metric_dict
+
     def on_save_checkpoint(self, checkpoint):
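The new test_step runs the model on each noisy batch, evaluates every configured metric wrapper (keyed by its .name), logs the whole dictionary through the experiment logger in one call, and returns the per-batch dictionary. A rough per-batch sketch of that aggregation, assuming the metric wrappers follow the __call__/.name convention from loss.py and the same (target, prediction) argument order used above:

def evaluate_batch(model, batch, metrics):
    """Compute every metric for one test batch, keyed by metric name."""
    prediction = model(batch["noisy"])
    target = batch["clean"]
    return {metric.name: metric(target, prediction) for metric in metrics}

# e.g. {"pesq": tensor(2.87), "stoi": tensor(0.91)} for one batch (values invented)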
@@ -5,7 +5,9 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
+pesq==0.0.4
 protobuf>=3.19.6
+pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3
 pytorch-lightning>=1.7.7
 scikit-learn>=1.1.2