Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

commit 5d8f49d78e
@@ -75,6 +75,7 @@ def main(config: DictConfig):
 
     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)
+    trainer.test(model)
 
     logger.experiment.log_artifact(
         logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
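Note: with this change the training entrypoint evaluates on the held-out test set right after fitting. A minimal sketch of the resulting flow (`model` and `trainer` stand in for the Hydra-instantiated objects above):

    # model: a pl.LightningModule, trainer: a pl.Trainer, both built inside main()
    trainer.fit(model)   # uses train_dataloader() / val_dataloader()
    trainer.test(model)  # new: runs test_step() over test_dataloader()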
@@ -5,6 +5,7 @@ from typing import Optional
 
 import pytorch_lightning as pl
 import torch.nn.functional as F
+from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
 from enhancer.data.fileprocessor import Fileprocessor
@@ -36,12 +37,24 @@ class ValidDataset(Dataset):
         return self.dataset.val__len__()
 
 
+class TestDataset(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, idx):
+        return self.dataset.test__getitem__(idx)
+
+    def __len__(self):
+        return self.dataset.test__len__()
+
+
 class TaskDataset(pl.LightningDataModule):
     def __init__(
         self,
         name: str,
         root_dir: str,
         files: Files,
+        valid_size: float = 0.20,
         duration: float = 1.0,
         sampling_rate: int = 48000,
         matching_function=None,
@@ -60,8 +73,15 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
+        if valid_size > 0.0:
+            self.valid_size = valid_size
+        else:
+            raise ValueError("valid_size must be greater than 0")
 
     def setup(self, stage: Optional[str] = None):
+        """
+        prepare train/validation/test data splits
+        """
 
         if stage in ("fit", None):
 
@@ -70,25 +90,33 @@ class TaskDataset(pl.LightningDataModule):
             fp = Fileprocessor.from_name(
                 self.name, train_clean, train_noisy, self.matching_function
             )
-            self.train_data = fp.prepare_matching_dict()
-
-            val_clean = os.path.join(self.root_dir, self.files.test_clean)
-            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
-            fp = Fileprocessor.from_name(
-                self.name, val_clean, val_noisy, self.matching_function
-            )
-            val_data = fp.prepare_matching_dict()
-
-            for item in val_data:
-                clean, noisy, total_dur = item.values()
-                if total_dur < self.duration:
-                    continue
-                num_segments = round(total_dur / self.duration)
-                for index in range(num_segments):
-                    start_time = index * self.duration
-                    self._validation.append(
-                        ({"clean": clean, "noisy": noisy}, start_time)
-                    )
+            train_data = fp.prepare_matching_dict()
+            self.train_data, self.val_data = train_test_split(
+                train_data, test_size=self.valid_size, shuffle=True, random_state=42
+            )
+
+            self._validation = self.prepare_mapstype(self.val_data)
+
+            test_clean = os.path.join(self.root_dir, self.files.test_clean)
+            test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, test_clean, test_noisy, self.matching_function
+            )
+            test_data = fp.prepare_matching_dict()
+            self._test = self.prepare_mapstype(test_data)
+
+    def prepare_mapstype(self, data):
+
+        metadata = []
+        for item in data:
+            clean, noisy, total_dur = item.values()
+            if total_dur < self.duration:
+                continue
+            num_segments = round(total_dur / self.duration)
+            for index in range(num_segments):
+                start_time = index * self.duration
+                metadata.append(({"clean": clean, "noisy": noisy}, start_time))
+        return metadata
 
     def train_dataloader(self):
         return DataLoader(
@@ -104,6 +132,13 @@ class TaskDataset(pl.LightningDataModule):
             num_workers=self.num_workers,
         )
 
+    def test_dataloader(self):
+        return DataLoader(
+            TestDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
+
 
 class EnhancerDataset(TaskDataset):
     """
@@ -137,6 +172,7 @@ class EnhancerDataset(TaskDataset):
         name: str,
         root_dir: str,
         files: Files,
+        valid_size=0.2,
         duration=1.0,
         sampling_rate=48000,
         matching_function=None,
@@ -148,6 +184,7 @@ class EnhancerDataset(TaskDataset):
             name=name,
             root_dir=root_dir,
             files=files,
+            valid_size=valid_size,
             sampling_rate=sampling_rate,
             duration=duration,
             matching_function=matching_function,
@@ -183,6 +220,9 @@ class EnhancerDataset(TaskDataset):
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
 
+    def test__getitem__(self, idx):
+        return self.prepare_segment(*self._test[idx])
+
     def prepare_segment(self, file_dict: dict, start_time: float):
 
         clean_segment = self.audio(
@@ -218,3 +258,6 @@ class EnhancerDataset(TaskDataset):
 
     def val__len__(self):
         return len(self._validation)
+
+    def test__len__(self):
+        return len(self._test)
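Note: validation pairs are no longer read from the test file lists; the training matches are shuffled and split with scikit-learn, and the test files become a true held-out split served by TestDataset. A hedged usage sketch (corpus name and paths are placeholders, not taken from this commit):

    from enhancer.data.dataset import EnhancerDataset

    ds = EnhancerDataset(
        name="vctk",            # placeholder corpus name
        root_dir="/data/vctk",  # placeholder path
        files=files,            # a Files spec, as elsewhere in the repo
        valid_size=0.2,         # new: fraction of train pairs held out for validation
    )
    ds.setup("fit")             # builds train/val via train_test_split, test from test files
    test_loader = ds.test_dataloader()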
@@ -55,7 +55,7 @@ class ProcessorFunctions:
         One clean audio file can have multiple noisy audio files
         """
 
-        matching_wavfiles = dict()
+        matching_wavfiles = list()
         clean_filenames = [
             file.split("/")[-1]
             for file in glob.glob(os.path.join(clean_path, "*.wav"))
@@ -73,7 +73,7 @@ class ProcessorFunctions:
                 if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                     sr_clean == sr_noisy
                 ):
-                    matching_wavfiles.update(
+                    matching_wavfiles.append(
                         {
                             "clean": os.path.join(clean_path, clean_file),
                             "noisy": noisy_file,
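Note: the dict-to-list change is the real fix here. The matching is one-to-many, and dict.update() kept reusing the same "clean"/"noisy" keys, so each new pair overwrote the previous one. An illustration with made-up file names:

    pairs = dict()
    pairs.update({"clean": "a.wav", "noisy": "a_0db.wav"})
    pairs.update({"clean": "b.wav", "noisy": "b_0db.wav"})
    print(pairs)       # {'clean': 'b.wav', 'noisy': 'b_0db.wav'} -- first pair lost

    pairs = list()
    pairs.append({"clean": "a.wav", "noisy": "a_0db.wav"})
    pairs.append({"clean": "b.wav", "noisy": "b_0db.wav"})
    print(len(pairs))  # 2 -- every clean/noisy match survives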
@@ -1,5 +1,9 @@
+import logging
+
 import torch
 import torch.nn as nn
+from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
+from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
 
 
 class mean_squared_error(nn.Module):
@@ -12,6 +16,7 @@ class mean_squared_error(nn.Module):
 
         self.loss_fun = nn.MSELoss(reduction=reduction)
         self.higher_better = False
+        self.name = "mse"
 
     def forward(self, prediction: torch.Tensor, target: torch.Tensor):
 
@@ -34,6 +39,7 @@ class mean_absolute_error(nn.Module):
 
         self.loss_fun = nn.L1Loss(reduction=reduction)
         self.higher_better = False
+        self.name = "mae"
 
     def forward(self, prediction: torch.Tensor, target: torch.Tensor):
 
@@ -46,13 +52,12 @@ class mean_absolute_error(nn.Module):
         return self.loss_fun(prediction, target)
 
 
-class Si_SDR(nn.Module):
+class Si_SDR:
     """
     SI-SDR metric based on SDR – HALF-BAKED OR WELL DONE? (https://arxiv.org/pdf/1811.02508.pdf)
     """
 
     def __init__(self, reduction: str = "mean"):
-        super().__init__()
         if reduction in ["sum", "mean", None]:
             self.reduction = reduction
         else:
@@ -60,8 +65,9 @@ class Si_SDR(nn.Module):
                 "Invalid reduction, valid options are sum, mean, None"
             )
         self.higher_better = False
+        self.name = "Si-SDR"
 
-    def forward(self, prediction: torch.Tensor, target: torch.Tensor):
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
 
         if prediction.size() != target.size() or target.ndim < 3:
             raise TypeError(
@@ -90,7 +96,40 @@
         return si_sdr
 
 
-class Avergeloss(nn.Module):
+class Stoi:
+    """
+    STOI (Short-Time Objective Intelligibility), a wrapper for the pystoi package.
+    Note that input will be moved to cpu to perform the metric calculation.
+    parameters:
+        sr: int
+            sampling rate
+    """
+
+    def __init__(self, sr: int):
+        self.sr = sr
+        self.stoi = ShortTimeObjectiveIntelligibility(fs=sr)
+        self.name = "stoi"
+
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
+
+        return self.stoi(prediction, target)
+
+
+class Pesq:
+    def __init__(self, sr: int, mode="nb"):
+
+        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
+        self.name = "pesq"
+
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
+        try:
+            return self.pesq(prediction, target)
+        except Exception as e:
+            logging.warning(f"{e} error occurred while calculating PESQ")
+            return 0.0
+
+
+class LossWrapper(nn.Module):
     """
     Combine multiple metrics of the same nature,
     for example ["mse", "mae"]
@@ -137,4 +176,6 @@ LOSS_MAP = {
     "mae": mean_absolute_error,
     "mse": mean_squared_error,
     "SI-SDR": Si_SDR,
+    "pesq": Pesq,
+    "stoi": Stoi,
 }
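Note: the new wrappers can be exercised in isolation. A hedged sketch with synthetic waveforms (real speech would normally be used; if the pesq backend rejects noise-like input, the wrapper logs a warning and returns 0.0, as coded above):

    import torch
    from enhancer.loss import Pesq, Stoi

    sr = 16000                     # PESQ supports 8 kHz ("nb") or 16 kHz
    target = torch.randn(1, sr)    # placeholder audio, 1 second
    prediction = target + 0.05 * torch.randn(1, sr)

    stoi = Stoi(sr=sr)             # wraps ShortTimeObjectiveIntelligibility
    pesq = Pesq(sr=sr, mode="nb")  # wraps PerceptualEvaluationSpeechQuality
    print(stoi(prediction, target), pesq(prediction, target))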
@@ -13,7 +13,7 @@ from torch.optim import Adam
 
 from enhancer.data.dataset import EnhancerDataset
 from enhancer.inference import Inference
-from enhancer.loss import Avergeloss
+from enhancer.loss import LOSS_MAP, LossWrapper
 from enhancer.version import __version__
 
 CACHE_DIR = ""
@@ -76,7 +76,7 @@ class Model(pl.LightningModule):
         if isinstance(loss, str):
             loss = [loss]
 
-        self._loss = Avergeloss(loss)
+        self._loss = LossWrapper(loss)
 
     @property
     def metric(self):
@@ -84,11 +84,21 @@
 
     @metric.setter
     def metric(self, metric):
-
+        self._metric = []
         if isinstance(metric, str):
             metric = [metric]
 
-        self._metric = Avergeloss(metric)
+        for func in metric:
+            if func in LOSS_MAP.keys():
+                if func in ("pesq", "stoi"):
+                    self._metric.append(
+                        LOSS_MAP[func](self.hparams.sampling_rate)
+                    )
+                else:
+                    self._metric.append(LOSS_MAP[func]())
+
+            else:
+                raise ValueError(f"Invalid metrics {func}")
 
     @property
     def dataset(self):
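Note: the setter now keeps a plain list of metric callables rather than a single averaged loss, which is what lets the test_step added below iterate over them and log each one by name. A hedged sketch (`model` is any constructed Model instance):

    model.metric = ["stoi", "pesq", "SI-SDR"]  # keys of LOSS_MAP
    print([m.name for m in model.metric])      # ['stoi', 'pesq', 'Si-SDR']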
@@ -109,6 +119,9 @@
     def val_dataloader(self):
         return self.dataset.val_dataloader()
 
+    def test_dataloader(self):
+        return self.dataset.test_dataloader()
+
     def configure_optimizers(self):
         return Adam(self.parameters(), lr=self.hparams.lr)
 
@@ -140,9 +153,7 @@
         target = batch["clean"]
         prediction = self(mixed_waveform)
 
-        metric_val = self.metric(prediction, target)
         loss_val = self.loss(prediction, target)
-        self.log("val_metric", metric_val.item())
         self.log("val_loss", loss_val.item())
 
         if (
@@ -156,14 +167,27 @@
                 value=loss_val.item(),
                 step=self.global_step,
             )
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="val_metric",
-                value=metric_val.item(),
-                step=self.global_step,
-            )
 
         return {"loss": loss_val}
 
+    def test_step(self, batch, batch_idx):
+
+        metric_dict = {}
+        mixed_waveform = batch["noisy"]
+        target = batch["clean"]
+        prediction = self(mixed_waveform)
+
+        for metric in self.metric:
+            value = metric(prediction, target)
+            metric_dict[metric.name] = value
+
+        self.logger.experiment.log_metrics(
+            run_id=self.logger.run_id,
+            metrics=metric_dict,
+            step=self.global_step,
+        )
+
+        return metric_dict
+
     def on_save_checkpoint(self, checkpoint):
 
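Note: tying the pieces together, validation now logs only val_loss, while the per-metric numbers move to the new test loop. A hedged end-to-end sketch (`model`/`trainer` are placeholders for the objects built in main()):

    model.metric = ["stoi", "pesq"]  # resolved through LOSS_MAP
    trainer.fit(model)               # train/val exactly as before
    trainer.test(model)              # test_step sends each batch's metric
                                     # dict to MLflow via log_metrics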
@@ -5,7 +5,9 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
+pesq==0.0.4
 protobuf>=3.19.6
+pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3
 pytorch-lightning>=1.7.7
 scikit-learn>=1.1.2