diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index abeccc8..7d25af8 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -4,10 +4,16 @@ from types import MethodType import hydra from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, +) from pytorch_lightning.loggers import MLFlowLogger from torch.optim.lr_scheduler import ReduceLROnPlateau +# from torch_audiomentations import Compose, Shift + os.environ["HYDRA_FULL_ERROR"] = "1" JOB_ID = os.environ.get("SLURM_JOBID", "0") @@ -25,8 +31,13 @@ def main(config: DictConfig): ) parameters = config.hyperparameters + # apply_augmentations = Compose( + # [ + # Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5), + # ] + # ) - dataset = instantiate(config.dataset) + dataset = instantiate(config.dataset, augmentations=None) model = instantiate( config.model, dataset=dataset, @@ -45,6 +56,8 @@ def main(config: DictConfig): every_n_epochs=1, ) callbacks.append(checkpoint) + callbacks.append(LearningRateMonitor(logging_interval="epoch")) + if parameters.get("Early_stop", False): early_stopping = EarlyStopping( monitor=f"valid_{parameters.get('EarlyStopping_metric','loss')}", @@ -56,11 +69,11 @@ def main(config: DictConfig): ) callbacks.append(early_stopping) - def configure_optimizer(self): + def configure_optimizers(self): optimizer = instantiate( config.optimizer, lr=parameters.get("lr"), - parameters=self.parameters(), + params=self.parameters(), ) scheduler = ReduceLROnPlateau( optimizer=optimizer, @@ -70,9 +83,13 @@ def main(config: DictConfig): min_lr=parameters.get("min_lr", 1e-6), patience=parameters.get("ReduceLr_patience", 3), ) - return {"optimizer": optimizer, "lr_scheduler": scheduler} + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}', + } - model.configure_parameters = MethodType(configure_optimizer, model) + model.configure_optimizers = MethodType(configure_optimizers, model) trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer.fit(model) diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml index 0b5ba7f..c33d29a 100644 --- a/enhancer/cli/train_config/dataset/Vctk.yaml +++ b/enhancer/cli/train_config/dataset/Vctk.yaml @@ -2,10 +2,10 @@ _target_: enhancer.data.dataset.EnhancerDataset name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 4.5 +stride : 2 sampling_rate: 16000 -batch_size: 128 -valid_minutes : 10 - +batch_size: 32 +valid_minutes : 15 files: train_clean : clean_trainset_28spk_wav test_clean : clean_testset_wav diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml index 7674906..1782ea9 100644 --- a/enhancer/cli/train_config/hyperparameters/default.yaml +++ b/enhancer/cli/train_config/hyperparameters/default.yaml @@ -1,8 +1,7 @@ -loss : mse +loss : mae metric : [stoi,pesq,si-sdr] lr : 0.0003 -ReduceLr_patience : 10 -ReduceLr_factor : 0.5 -min_lr : 0.00 -early_stop : True +ReduceLr_patience : 5 +ReduceLr_factor : 0.2 +min_lr : 0.000001 EarlyStopping_factor : 10 diff --git a/enhancer/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml index e8893f6..d597333 100644 --- a/enhancer/cli/train_config/mlflow/experiment.yaml +++ b/enhancer/cli/train_config/mlflow/experiment.yaml @@ -1,2 +1,2 @@ experiment_name : shahules/enhancer -run_name : baseline +run_name : Demucs + Vtck with stride + augmentations diff --git a/enhancer/cli/train_config/model/Demucs.yaml b/enhancer/cli/train_config/model/Demucs.yaml index e15dbc9..0a051b5 100644 --- a/enhancer/cli/train_config/model/Demucs.yaml +++ b/enhancer/cli/train_config/model/Demucs.yaml @@ -1,11 +1,11 @@ _target_: enhancer.models.demucs.Demucs num_channels: 1 -resample: 2 +resample: 4 sampling_rate : 16000 encoder_decoder: - depth: 5 - initial_output_channels: 32 + depth: 4 + initial_output_channels: 64 kernel_size: 8 stride: 4 growth_factor: 2 diff --git a/enhancer/cli/train_config/model/WaveUnet.yaml b/enhancer/cli/train_config/model/WaveUnet.yaml index d641bcd..29d48c7 100644 --- a/enhancer/cli/train_config/model/WaveUnet.yaml +++ b/enhancer/cli/train_config/model/WaveUnet.yaml @@ -1,5 +1,5 @@ _target_: enhancer.models.waveunet.WaveUnet num_channels : 1 -depth : 12 +depth : 9 initial_output_channels: 24 sampling_rate : 16000 diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml index 2d6422c..8bdf60f 100644 --- a/enhancer/cli/train_config/trainer/default.yaml +++ b/enhancer/cli/train_config/trainer/default.yaml @@ -9,7 +9,7 @@ benchmark: False check_val_every_n_epoch: 1 detect_anomaly: False deterministic: False -devices: 1 +devices: 2 enable_checkpointing: True enable_model_summary: True enable_progress_bar: True @@ -22,9 +22,10 @@ limit_predict_batches: 1.0 limit_test_batches: 1.0 limit_train_batches: 1.0 limit_val_batches: 1.0 -log_every_n_steps: 100 +log_every_n_steps: 50 max_epochs: 200 -max_time: 00:47:00:00 +max_steps: -1 +max_time: null min_epochs: 1 min_steps: null move_metrics_to_cpu: False @@ -37,7 +38,7 @@ precision: 32 profiler: null reload_dataloaders_every_n_epochs: 0 replace_sampler_ddp: True -strategy: null +strategy: ddp sync_batchnorm: False tpu_cores: null track_grad_norm: -1 diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 624a796..dac2c50 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -1,14 +1,15 @@ import math import multiprocessing import os -import random -from itertools import chain, cycle +from pathlib import Path from typing import Optional +import numpy as np import pytorch_lightning as pl import torch import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset, IterableDataset +from torch.utils.data import DataLoader, Dataset +from torch_audiomentations import Compose from enhancer.data.fileprocessor import Fileprocessor from enhancer.utils import check_files @@ -16,13 +17,15 @@ from enhancer.utils.config import Files from enhancer.utils.io import Audio from enhancer.utils.random import create_unique_rng +LARGE_NUM = 2147483647 -class TrainDataset(IterableDataset): + +class TrainDataset(Dataset): def __init__(self, dataset): self.dataset = dataset - def __iter__(self): - return self.dataset.train__iter__() + def __getitem__(self, idx): + return self.dataset.train__getitem__(idx) def __len__(self): return self.dataset.train__len__() @@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule): matching_function=None, batch_size=32, num_workers: Optional[int] = None, + augmentations: Optional[Compose] = None, ): super().__init__() @@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule): else: raise ValueError("valid_minutes must be greater than 0") + self.augmentations = augmentations + def setup(self, stage: Optional[str] = None): """ prepare train/validation/test data splits @@ -115,16 +121,29 @@ class TaskDataset(pl.LightningDataModule): ): valid_minutes *= 60 - valid_min_now = 0.0 + valid_sec_now = 0.0 valid_indices = [] - random_indices = list(range(0, len(data))) - rng = create_unique_rng(random_state) - rng.shuffle(random_indices) - i = 0 - while valid_min_now <= valid_minutes: - valid_indices.append(random_indices[i]) - valid_min_now += data[random_indices[i]]["duration"] - i += 1 + all_speakers = np.unique( + [ + (Path(file["clean"]).name.split("_")[0], file["duration"]) + for file in data + ] + ) + possible_indices = list(range(0, len(all_speakers))) + rng = create_unique_rng(len(all_speakers)) + + while valid_sec_now <= valid_minutes: + speaker_index = rng.choice(possible_indices) + possible_indices.remove(speaker_index) + speaker_name = all_speakers[speaker_index] + file_indices = [ + i + for i, file in enumerate(data) + if speaker_name == Path(file["clean"]).name.split("_")[0] + ] + for i in file_indices: + valid_indices.append(i) + valid_sec_now += data[i]["duration"] train_data = [ item for i, item in enumerate(data) if i not in valid_indices @@ -135,16 +154,11 @@ class TaskDataset(pl.LightningDataModule): def prepare_traindata(self, data): train_data = [] for item in data: - samples_metadata = [] clean, noisy, total_dur = item.values() num_segments = self.get_num_segments( total_dur, self.duration, self.stride ) - for index in range(num_segments): - start = index * self.stride - samples_metadata.append( - ({"clean": clean, "noisy": noisy}, start) - ) + samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments) train_data.append(samples_metadata) return train_data @@ -166,7 +180,9 @@ class TaskDataset(pl.LightningDataModule): if total_dur < self.duration: metadata.append(({"clean": clean, "noisy": noisy}, 0.0)) else: - num_segments = round(total_dur / self.duration) + num_segments = self.get_num_segments( + total_dur, self.duration, self.duration + ) for index in range(num_segments): start_time = index * self.duration metadata.append( @@ -175,31 +191,44 @@ class TaskDataset(pl.LightningDataModule): return metadata def train_collatefn(self, batch): - output = {"noisy": [], "clean": []} + + output = {"clean": [], "noisy": []} for item in batch: - output["noisy"].append(item["noisy"]) output["clean"].append(item["clean"]) + output["noisy"].append(item["noisy"]) output["clean"] = torch.stack(output["clean"], dim=0) output["noisy"] = torch.stack(output["noisy"], dim=0) + + if self.augmentations is not None: + noise = output["noisy"] - output["clean"] + output["clean"] = self.augmentations( + output["clean"], sample_rate=self.sampling_rate + ) + self.augmentations.freeze_parameters() + output["noisy"] = ( + self.augmentations(noise, sample_rate=self.sampling_rate) + + output["clean"] + ) + return output - def worker_init_fn(self, _): - worker_info = torch.utils.data.get_worker_info() - dataset = worker_info.dataset - worker_id = worker_info.id - split_size = len(dataset.dataset.train_data) // worker_info.num_workers - dataset.data = dataset.dataset.train_data[ - worker_id * split_size : (worker_id + 1) * split_size - ] + @property + def generator(self): + generator = torch.Generator() + if hasattr(self, "model"): + seed = self.model.current_epoch + LARGE_NUM + else: + seed = LARGE_NUM + return generator.manual_seed(seed) def train_dataloader(self): return DataLoader( TrainDataset(self), - batch_size=None, + batch_size=self.batch_size, num_workers=self.num_workers, + generator=self.generator, collate_fn=self.train_collatefn, - worker_init_fn=self.worker_init_fn, ) def val_dataloader(self): @@ -256,6 +285,7 @@ class EnhancerDataset(TaskDataset): matching_function=None, batch_size=32, num_workers: Optional[int] = None, + augmentations: Optional[Compose] = None, ): super().__init__( @@ -268,6 +298,7 @@ class EnhancerDataset(TaskDataset): matching_function=matching_function, batch_size=batch_size, num_workers=num_workers, + augmentations=augmentations, ) self.sampling_rate = sampling_rate @@ -280,35 +311,17 @@ class EnhancerDataset(TaskDataset): super().setup(stage=stage) - def random_sample(self, train_data): - return random.sample(train_data, len(train_data)) + def train__getitem__(self, idx): - def train__iter__(self): - rng = create_unique_rng(self.model.current_epoch) - train_data = rng.sample(self.train_data, len(self.train_data)) - return zip( - *[ - self.get_stream(self.random_sample(train_data)) - for i in range(self.batch_size) - ] - ) - - def get_stream(self, data): - return chain.from_iterable(map(self.process_data, cycle(data))) - - def process_data(self, data): - for item in data: - yield self.prepare_segment(*item) - - @staticmethod - def get_num_segments(file_duration, duration, stride): - - if file_duration < duration: - num_segments = 1 - else: - num_segments = math.ceil((file_duration - duration) / stride) + 1 - - return num_segments + for filedict, num_samples in self.train_data: + if idx >= num_samples: + idx -= num_samples + continue + else: + start = 0 + if self.duration is not None: + start = idx * self.stride + return self.prepare_segment(filedict, start) def val__getitem__(self, idx): return self.prepare_segment(*self._validation[idx]) @@ -348,7 +361,8 @@ class EnhancerDataset(TaskDataset): } def train__len__(self): - return sum([len(item) for item in self.train_data]) // (self.batch_size) + _, num_examples = list(zip(*self.train_data)) + return sum(num_examples) def val__len__(self): return len(self._validation) diff --git a/enhancer/loss.py b/enhancer/loss.py index 2150699..32b30cf 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -3,7 +3,7 @@ import logging import numpy as np import torch import torch.nn as nn -from pesq import pesq +from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility @@ -66,7 +66,7 @@ class Si_SDR: "Invalid reduction, valid options are sum, mean, None" ) self.higher_better = False - self.name = "Si-SDR" + self.name = "si-sdr" def __call__(self, prediction: torch.Tensor, target: torch.Tensor): @@ -122,20 +122,16 @@ class Pesq: self.sr = sr self.name = "pesq" self.mode = mode + self.pesq = PerceptualEvaluationSpeechQuality( + fs=self.sr, mode=self.mode + ) def __call__(self, prediction: torch.Tensor, target: torch.Tensor): pesq_values = [] for pred, target_ in zip(prediction, target): try: - pesq_values.append( - pesq( - self.sr, - target_.squeeze().detach().cpu().numpy(), - pred.squeeze().detach().cpu().numpy(), - self.mode, - ) - ) + pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze())) except Exception as e: logging.warning(f"{e} error occured while calculating PESQ") return torch.tensor(np.mean(pesq_values)) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index dc5219d..3b60b85 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -113,6 +113,8 @@ class Model(pl.LightningModule): if stage == "fit": torch.cuda.empty_cache() self.dataset.setup(stage) + self.dataset.model = self + print( "Total train duration", self.dataset.train_dataloader().dataset.__len__() @@ -134,7 +136,6 @@ class Model(pl.LightningModule): / 60, "minutes", ) - self.dataset.model = self def train_dataloader(self): return self.dataset.train_dataloader() diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py index 9e9ce32..d151ef8 100644 --- a/enhancer/utils/io.py +++ b/enhancer/utils/io.py @@ -70,7 +70,7 @@ class Audio: if sampling_rate: audio = self.__class__.resample_audio( - audio, self.sampling_rate, sampling_rate + audio, sampling_rate, self.sampling_rate ) if self.return_tensor: return torch.tensor(audio) diff --git a/requirements.txt b/requirements.txt index fa5e41c..fb54920 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ joblib>=1.2.0 librosa>=0.9.2 mlflow>=1.29.0 numpy>=1.23.3 -git+https://github.com/ludlows/python-pesq#egg=pesq +pesq==0.0.4 protobuf>=3.19.6 pystoi==0.3.3 pytest-lazy-fixture>=0.6.3 @@ -14,5 +14,6 @@ scikit-learn>=1.1.2 scipy>=1.9.1 soundfile>=0.11.0 torch>=1.12.1 +torch-audiomentations==0.11.0 torchaudio>=0.12.1 tqdm>=4.64.1