From 1dcaa7fbe05313487c34fc68c28640e4cc130036 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sun, 2 Oct 2022 09:05:02 +0530 Subject: [PATCH 01/10] change config --- cli/train_config/config.yaml | 2 +- cli/train_config/dataset/Vctk.yaml | 2 +- cli/train_config/hyperparameters/default.yaml | 1 - cli/train_config/trainer/default.yaml | 4 ++-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml index 6b5d98e..61551bd 100644 --- a/cli/train_config/config.yaml +++ b/cli/train_config/config.yaml @@ -1,5 +1,5 @@ defaults: - - model : Demucs + - model : WaveUnet - dataset : Vctk - optimizer : Adam - hyperparameters : default diff --git a/cli/train_config/dataset/Vctk.yaml b/cli/train_config/dataset/Vctk.yaml index d1c8646..aed10ac 100644 --- a/cli/train_config/dataset/Vctk.yaml +++ b/cli/train_config/dataset/Vctk.yaml @@ -3,7 +3,7 @@ name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 1.0 sampling_rate: 16000 -batch_size: 8 +batch_size: 64 files: train_clean : clean_trainset_56spk_wav diff --git a/cli/train_config/hyperparameters/default.yaml b/cli/train_config/hyperparameters/default.yaml index 04b099b..a0ac704 100644 --- a/cli/train_config/hyperparameters/default.yaml +++ b/cli/train_config/hyperparameters/default.yaml @@ -1,4 +1,3 @@ loss : mse metric : mae lr : 0.0001 -num_epochs : 100 diff --git a/cli/train_config/trainer/default.yaml b/cli/train_config/trainer/default.yaml index ab4e273..6c693d8 100644 --- a/cli/train_config/trainer/default.yaml +++ b/cli/train_config/trainer/default.yaml @@ -22,8 +22,8 @@ limit_predict_batches: 1.0 limit_test_batches: 1.0 limit_train_batches: 1.0 limit_val_batches: 1.0 -log_every_n_steps: 10 -max_epochs: 100 +log_every_n_steps: 50 +max_epochs: 500 max_steps: null max_time: null min_epochs: 1 From f1604b0f0e568b0d4495db8d01a4557f21a2668b Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sun, 2 Oct 2022 10:01:44 +0530 Subject: [PATCH 02/10] demucs model --- cli/train_config/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml index 61551bd..6b5d98e 100644 --- a/cli/train_config/config.yaml +++ b/cli/train_config/config.yaml @@ -1,5 +1,5 @@ defaults: - - model : WaveUnet + - model : Demucs - dataset : Vctk - optimizer : Adam - hyperparameters : default From c8815fa969dc54c6560f4d1b29764448b7cd0e2f Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 12:24:15 +0530 Subject: [PATCH 03/10] vctk 28 spkrs --- cli/train_config/dataset/Vctk.yaml | 4 ++-- cli/train_config/hyperparameters/default.yaml | 1 - cli/train_config/trainer/default.yaml | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cli/train_config/dataset/Vctk.yaml b/cli/train_config/dataset/Vctk.yaml index aed10ac..129d9a8 100644 --- a/cli/train_config/dataset/Vctk.yaml +++ b/cli/train_config/dataset/Vctk.yaml @@ -6,9 +6,9 @@ sampling_rate: 16000 batch_size: 64 files: - train_clean : clean_trainset_56spk_wav + train_clean : clean_trainset_28spk_wav test_clean : clean_testset_wav - train_noisy : noisy_trainset_56spk_wav + train_noisy : noisy_trainset_28spk_wav test_noisy : noisy_testset_wav diff --git a/cli/train_config/hyperparameters/default.yaml b/cli/train_config/hyperparameters/default.yaml index 49a7f80..82ac3c2 100644 --- a/cli/train_config/hyperparameters/default.yaml +++ b/cli/train_config/hyperparameters/default.yaml @@ -1,7 +1,6 @@ loss : mse metric : mae lr : 0.0001 -num_epochs : 100 ReduceLr_patience : 5 ReduceLr_factor : 0.1 min_lr : 0.000001 diff --git a/cli/train_config/trainer/default.yaml b/cli/train_config/trainer/default.yaml index 2aa5083..55101de 100644 --- a/cli/train_config/trainer/default.yaml +++ b/cli/train_config/trainer/default.yaml @@ -22,8 +22,8 @@ limit_predict_batches: 1.0 limit_test_batches: 1.0 limit_train_batches: 1.0 limit_val_batches: 1.0 -log_every_n_steps: 10 -max_epochs: 30 +log_every_n_steps: 1 +max_epochs: 10 max_steps: null max_time: null min_epochs: 1 From 158591176708c817bdaec3a9180fbe4f06564c0a Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 16:09:05 +0530 Subject: [PATCH 04/10] ckpt --- cli/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/train.py b/cli/train.py index 8a4beaf..cd19ffc 100644 --- a/cli/train.py +++ b/cli/train.py @@ -56,7 +56,7 @@ def main(config: DictConfig): trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) trainer.fit(model) if os.path.exists("./model/"): - logger.experiment.log_artifact(logger.run_id,f"./model/.*") + logger.experiment.log_artifact(logger.run_id,f"model_{JOB_ID}.ckpt") From 07c525ca150fc3bf81b3d7272689cee355243404 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 20:00:14 +0530 Subject: [PATCH 05/10] fix key error --- enhancer/models/model.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index b1bdd86..5827301 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -191,8 +191,8 @@ class Model(pl.LightningModule): map_location = torch.device(DEFAULT_DEVICE) loaded_checkpoint = pl_load(model_path_pl,map_location) - module_name = loaded_checkpoint["architecture"]["module"] - class_name = loaded_checkpoint["architecture"]["class"] + module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] + class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] module = import_module(module_name) Klass = getattr(module, class_name) @@ -216,11 +216,12 @@ class Model(pl.LightningModule): batch_predictions = [] self.eval().to(self.device) - for batch_id in range(batch.shape[0],batch_size): - batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) - prediction = self(batch_data) - batch_predictions.append(prediction) - + with torch.no_grad(): + for batch_id in range(0,batch.shape[0],batch_size): + batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) + prediction = self(batch_data) + batch_predictions.append(prediction) + return torch.vstack(batch_predictions) def enhance( @@ -232,9 +233,9 @@ class Model(pl.LightningModule): duration:Optional[int]=None, step_size:Optional[int]=None,): - model_sampling_rate = self.model.hprams("sampling_rate") + model_sampling_rate = self.hparams["sampling_rate"] if duration is None: - duration = self.model.hparams("duration") + duration = self.hparams["duration"] waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) waveform.to(self.device) window_size = round(duration * model_sampling_rate) @@ -246,8 +247,9 @@ class Model(pl.LightningModule): Inference.write_output(waveform,audio,model_sampling_rate) else: - return waveform - + waveform = Inference.prepare_output(waveform, model_sampling_rate, + audio, sampling_rate) + return waveform @property def valid_monitor(self): From 5e5fd9d9b02a3f5485a26502c71041ecbdd9add4 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 20:00:35 +0530 Subject: [PATCH 06/10] prepare output type/sr --- enhancer/inference.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/enhancer/inference.py b/enhancer/inference.py index 6e9cff7..2c63d54 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -18,6 +18,8 @@ class Inference: if isinstance(audio,(np.ndarray,torch.Tensor)): assert sr is not None, "Invalid sampling rate!" + if len(audio.shape) == 1: + audio = audio.reshape(1,-1) if isinstance(audio,str): audio = Path(audio) @@ -65,6 +67,8 @@ class Inference: window = get_window(window=window,Nx=data.shape[-1]) window = torch.from_numpy(window).to(data.device) data *= window + step_size = window_size//2 if step_size is None else step_size + data = data.permute(1,2,0) data = F.fold(data, @@ -84,7 +88,18 @@ class Inference: raise FileExistsError(f"file {filename} already exists") else: wavfile.write(filename,rate=sr,data=waveform.detach().cpu()) - + + @staticmethod + def prepare_output(waveform:torch.Tensor, model_sampling_rate:int, + audio:Union[str,np.ndarray,torch.Tensor], sampling_rate:Optional[int] + ): + if isinstance(audio,np.ndarray): + waveform = waveform.detach().cpu().numpy() + + if sampling_rate!=None: + waveform = Audio.resample_audio(waveform, sr=model_sampling_rate, target_sr=sampling_rate) + + return waveform From ecd47905dd95256cffa52ffb40967a9f75ee4f1c Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 20:00:49 +0530 Subject: [PATCH 07/10] relative imports --- enhancer/models/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py index 0d21337..534a608 100644 --- a/enhancer/models/__init__.py +++ b/enhancer/models/__init__.py @@ -1,2 +1,3 @@ from enhancer.models.demucs import Demucs -from enhancer.models.waveunet import WaveUnet \ No newline at end of file +from enhancer.models.waveunet import WaveUnet +from enhancer.models.model import Model \ No newline at end of file From a880125322c9201c5c3059248f130c17001b7971 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 20:01:19 +0530 Subject: [PATCH 08/10] fix logging path --- cli/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cli/train.py b/cli/train.py index cd19ffc..dee3d2e 100644 --- a/cli/train.py +++ b/cli/train.py @@ -1,3 +1,4 @@ +from genericpath import isfile import os from types import MethodType import hydra @@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.loggers import MLFlowLogger os.environ["HYDRA_FULL_ERROR"] = "1" -JOB_ID = os.environ.get("SLURM_JOBID") +JOB_ID = os.environ.get("SLURM_JOBID","0") @hydra.main(config_path="train_config",config_name="config") def main(config: DictConfig): @@ -55,8 +56,10 @@ def main(config: DictConfig): trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) trainer.fit(model) - if os.path.exists("./model/"): - logger.experiment.log_artifact(logger.run_id,f"model_{JOB_ID}.ckpt") + + saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt") + if os.path.isfile(saved_location): + logger.experiment.log_artifact(logger.run_id,saved_location) From 687f67e40c2c5b94ab2da160fca40bad529854f4 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 21:21:27 +0530 Subject: [PATCH 09/10] write output fix --- enhancer/inference.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/enhancer/inference.py b/enhancer/inference.py index 2c63d54..fd2f57a 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -84,10 +84,15 @@ class Inference: if isinstance(filename,str): filename = Path(filename) + + parent, name = filename.parent, "cleaned_"+filename.name + filename = parent/Path(name) if filename.is_file(): raise FileExistsError(f"file {filename} already exists") else: - wavfile.write(filename,rate=sr,data=waveform.detach().cpu()) + if isinstance(waveform,torch.Tensor): + waveform = waveform.detach().cpu().squeeze().numpy() + wavfile.write(filename,rate=sr,data=waveform) @staticmethod def prepare_output(waveform:torch.Tensor, model_sampling_rate:int, From c609b57309ca841107343b65921308cbcc0ded3a Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 21:22:15 +0530 Subject: [PATCH 10/10] fix typo --- enhancer/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 5827301..de2edab 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -249,7 +249,7 @@ class Model(pl.LightningModule): else: waveform = Inference.prepare_output(waveform, model_sampling_rate, audio, sampling_rate) - return waveform + return waveform @property def valid_monitor(self):