From 0a80521c0295223f0d42fc7b84dbddc9bcd4584e Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 11:55:17 +0530
Subject: [PATCH 01/35] rmv matplotlib

---
 enhancer/utils/io.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py
index 80bd1a4..afc19e8 100644
--- a/enhancer/utils/io.py
+++ b/enhancer/utils/io.py
@@ -1,7 +1,6 @@
 import os
 import librosa
 from typing import Optional
-from matplotlib.pyplot import axis
 import numpy as np
 import torch
 import torchaudio

From 6c595e844621c6233b99ea754a650dca82518a65 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 12:06:06 +0530
Subject: [PATCH 02/35] torchaudio

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index c74e46d..e7fcd24 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,5 +10,6 @@ tqdm==4.64.0
 mlflow==1.23.1
 protobuf==3.19.3
 boto3==1.23.9
+torchaudio==0.10.2
 huggingface-hub==0.4.0
 pytorch-lightning==1.5.10

From 443c0a93f5cc8a1bb91374b5feb079dca3d652ed Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 12:49:41 +0530
Subject: [PATCH 03/35] set hydra to full error

---
 hpc_entrypoint.sh | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/hpc_entrypoint.sh b/hpc_entrypoint.sh
index 2a4e3c7..7372eb9 100644
--- a/hpc_entrypoint.sh
+++ b/hpc_entrypoint.sh
@@ -22,6 +22,8 @@ echo "Activate Environment"
 source activate enhancer
 export TRANSFORMERS_OFFLINE=True
 export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
+export HYDRA_FULL_ERROR=1
+
 echo $PYTHONPATH
 source ~/mlflow_settings.sh

From f4e625e41289396d8e1008c66e25bee0fbacef37 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 12:49:59 +0530
Subject: [PATCH 04/35] rmv import

---
 cli/train.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/cli/train.py b/cli/train.py
index 9ed1cd0..9aa497d 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -4,8 +4,6 @@ from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 
-from enhancer.data.dataset import EnhancerDataset
-
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):

From fb31711653d20c7abf1f87b438b04da132e6d650 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 12:50:07 +0530
Subject: [PATCH 05/35] rmv import

---
 enhancer/models/demucs.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py
index 115f63e..0bb81d1 100644
--- a/enhancer/models/demucs.py
+++ b/enhancer/models/demucs.py
@@ -1,5 +1,3 @@
-from base64 import encode
-from turtle import forward
 from typing import Optional, Union, List
 from torch import nn
 import torch.nn.functional as F

From 52314a817c9da1677d8b72ca7c6f4382363ee5eb Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 15:07:36 +0530
Subject: [PATCH 06/35] rmv unused imports

---
 enhancer/inference.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/enhancer/inference.py b/enhancer/inference.py
index 404fe95..6e9cff7 100644
--- a/enhancer/inference.py
+++ b/enhancer/inference.py
@@ -10,7 +10,6 @@ from pathlib import Path
 from librosa import load as load_audio
 
 from enhancer.utils import Audio
-from enhancer.utils.config import DEFAULT_DEVICE
 
 class Inference:

From 52788c029baca5840497be5db04d06bb1617b169 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 15:07:55 +0530
Subject: [PATCH 07/35] rmv unused dataclasses

---
 enhancer/utils/config.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py
index 0aaa2e3..e9af6a0 100644
--- a/enhancer/utils/config.py
+++ b/enhancer/utils/config.py
@@ -1,18 +1,11 @@
 from dataclasses import dataclass
 
-@dataclass
-class Paths:
-    log : str
-    data : str
-
 @dataclass
 class Files:
+    root_dir : str
     train_clean : str
     train_noisy : str
     test_clean : str
     test_noisy : str
 
-@dataclass
-class EnhancerConfig:
-    path : Paths
-    files: Files
\ No newline at end of file
+

From ea3861eca92957be26417c83570bf3893e0d6fce Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 15:08:13 +0530
Subject: [PATCH 08/35] rmv np

---
 tests/test_inference.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_inference.py b/tests/test_inference.py
index 303aa67..5eb7442 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -1,6 +1,5 @@
 import pytest
 import torch
-import numpy as np
 
 from enhancer.inference import Inference

From 60b5d00bab9dc7df59c8655da3f379f2315c6924 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 21:51:00 +0530
Subject: [PATCH 09/35] set trainer to default

---
 cli/train_config/config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
index 7845b01..6b5d98e 100644
--- a/cli/train_config/config.yaml
+++ b/cli/train_config/config.yaml
@@ -3,5 +3,5 @@ defaults:
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
-  - trainer : fastrun_dev
+  - trainer : default
   - mlflow : experiment
\ No newline at end of file
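Patches 09, 14 and 18 toggle the `trainer` config group by editing config.yaml back and forth. A minimal sketch of making the same switch without touching the file, assuming a Hydra version that exposes the compose API (`hydra-core>=1.0`) and that `cli/train_config` resolves relative to the caller:

```python
# Hypothetical standalone usage; equivalent to `python cli/train.py trainer=fastrun_dev`.
from hydra import compose, initialize

with initialize(config_path="cli/train_config"):
    cfg = compose(config_name="config", overrides=["trainer=fastrun_dev"])
    print(cfg.trainer["_target_"])  # pytorch_lightning.Trainer
```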
From 658e4d08a5df00774392ce9e54ec7e9f1b9d680d Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 28 Sep 2022 22:05:10 +0530
Subject: [PATCH 10/35] add num_workers as arg

---
 enhancer/data/dataset.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index f4e7e4a..98abe8a 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -1,6 +1,4 @@
-
-from dataclasses import dataclass
-import glob
+import multiprocessing
 import math
 import os
 import pytorch_lightning as pl
@@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule):
                 duration:float=1.0,
                 sampling_rate:int=48000,
                 matching_function = None,
-                batch_size=32):
+                batch_size=32,
+                num_workers:Optional[int]=None):
 
         super().__init__()
         self.name = name
@@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule):
         self.batch_size = batch_size
         self.matching_function = matching_function
         self._validation = []
+        if num_workers is None:
+            num_workers = multiprocessing.cpu_count()//2
+        self.num_workers = num_workers
 
     def setup(self, stage: Optional[str] = None):
@@ -85,10 +87,10 @@ class TaskDataset(pl.LightningDataModule):
                 self._validation.append(({"clean":clean,"noisy":noisy}, start_time))
 
     def train_dataloader(self):
-        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2)
+        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
 
     def val_dataloader(self):
-        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2)
+        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
 
 class EnhancerDataset(TaskDataset):
     """Dataset object for creating clean-noisy speech enhancement datasets"""
@@ -101,7 +103,8 @@ class EnhancerDataset(TaskDataset):
                 duration=1.0,
                 sampling_rate=48000,
                 matching_function=None,
-                batch_size=32):
+                batch_size=32,
+                num_workers:Optional[int]=None):
 
         super().__init__(
                         name=name,
@@ -110,7 +113,8 @@ class EnhancerDataset(TaskDataset):
                         sampling_rate=sampling_rate,
                         duration=duration,
                         matching_function = matching_function,
-                        batch_size=batch_size
+                        batch_size=batch_size,
+                        num_workers = num_workers,
        )
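With this change every DataLoader inherits a worker count, defaulting to half the visible CPUs. A minimal sketch of constructing the dataset with an explicit override; the paths are placeholders, and the leading name/root_dir/files parameters are assumed from the hunks above:

```python
from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.config import Files

files = Files(root_dir="/path/to/DS_10283_2791",
              train_clean="clean_train", train_noisy="noisy_train",
              test_clean="clean_test", test_noisy="noisy_test")
dataset = EnhancerDataset(name="vctk", root_dir=files.root_dir, files=files,
                          duration=1.0, sampling_rate=48000,
                          batch_size=8, num_workers=4)
# num_workers=None (the default) falls back to multiprocessing.cpu_count()//2.
```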
From 7be0a2d0bd98c06b12c1cb3182b27f35708ac6e5 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 09:39:01 +0530
Subject: [PATCH 11/35] rename to yaml

---
 cli/train_config/trainer/{default.yml => default.yaml} | 1 -
 1 file changed, 1 deletion(-)
 rename cli/train_config/trainer/{default.yml => default.yaml} (98%)

diff --git a/cli/train_config/trainer/default.yml b/cli/train_config/trainer/default.yaml
similarity index 98%
rename from cli/train_config/trainer/default.yml
rename to cli/train_config/trainer/default.yaml
index eeb5b85..6c15867 100644
--- a/cli/train_config/trainer/default.yml
+++ b/cli/train_config/trainer/default.yaml
@@ -1,4 +1,3 @@
-# @package _group_
 _target_: pytorch_lightning.Trainer
 accelerator: auto
 accumulate_grad_batches: 1

From 8271573e1c4c87545ad7a4e44c987d5c001c831a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 09:39:16 +0530
Subject: [PATCH 12/35] rmv deprecated

---
 cli/train_config/trainer/fastrun_dev.yaml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/cli/train_config/trainer/fastrun_dev.yaml b/cli/train_config/trainer/fastrun_dev.yaml
index 5d0895f..682149e 100644
--- a/cli/train_config/trainer/fastrun_dev.yaml
+++ b/cli/train_config/trainer/fastrun_dev.yaml
@@ -1,3 +1,2 @@
-# @package _group_
 _target_: pytorch_lightning.Trainer
 fast_dev_run: True

From a05efc7866a7094228a97bdc54be80087435870a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 10:14:25 +0530
Subject: [PATCH 13/35] change exp name

---
 cli/train_config/mlflow/experiment.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cli/train_config/mlflow/experiment.yaml b/cli/train_config/mlflow/experiment.yaml
index b64b125..2995c60 100644
--- a/cli/train_config/mlflow/experiment.yaml
+++ b/cli/train_config/mlflow/experiment.yaml
@@ -1,2 +1,2 @@
-experiment_name : "myexp"
-run_name : "myrun"
\ No newline at end of file
+experiment_name : shahules/enhancer
+run_name : baseline
\ No newline at end of file

From 31a3335ff026d3a806c5ceb8cd81df5ab5bb5047 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 10:16:59 +0530
Subject: [PATCH 14/35] do fastrun

---
 cli/train_config/config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
index 6b5d98e..7845b01 100644
--- a/cli/train_config/config.yaml
+++ b/cli/train_config/config.yaml
@@ -3,5 +3,5 @@ defaults:
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
-  - trainer : default
+  - trainer : fastrun_dev
   - mlflow : experiment
\ No newline at end of file

From 206355270ea300156563304e0b25d8b76a6dfd68 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 10:36:47 +0530
Subject: [PATCH 15/35] update progress bar

---
 cli/train.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/cli/train.py b/cli/train.py
index 9aa497d..5a056a2 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -3,11 +3,14 @@ from hydra.utils import instantiate
 from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
+from pytorch_lightning.callbacks import TQDMProgressBar
+
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
 
     callbacks = []
+    callbacks.append(TQDMProgressBar(refresh_rate=10))
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
                           run_name=config.mlflow.run_name)
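The refresh rate matters on the cluster: a TQDM bar that redraws every batch floods non-interactive SLURM logs. A sketch of the callback wiring, assuming pytorch-lightning 1.5.x, where `TQDMProgressBar` lives under `pytorch_lightning.callbacks`:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar

# Redraw the bar only every 10 batches instead of every batch.
callbacks = [TQDMProgressBar(refresh_rate=10)]
trainer = pl.Trainer(fast_dev_run=True, callbacks=callbacks)
```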
From 6fd5f964ddf31eddad536c3c1021084888e8fad4 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 10:37:00 +0530
Subject: [PATCH 16/35] increase num epochs

---
 cli/train_config/hyperparameters/default.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cli/train_config/hyperparameters/default.yaml b/cli/train_config/hyperparameters/default.yaml
index 5cbdcb0..4931c7c 100644
--- a/cli/train_config/hyperparameters/default.yaml
+++ b/cli/train_config/hyperparameters/default.yaml
@@ -1,4 +1,4 @@
 loss : mse
 metric : mae
 lr : 0.001
-num_epochs : 10
+num_epochs : 100

From 387fc0149394dc98b2afea16695f6fe3021c3b25 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 10:37:52 +0530
Subject: [PATCH 17/35] decrease num epochs

---
 cli/train_config/trainer/default.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cli/train_config/trainer/default.yaml b/cli/train_config/trainer/default.yaml
index 6c15867..633c6ba 100644
--- a/cli/train_config/trainer/default.yaml
+++ b/cli/train_config/trainer/default.yaml
@@ -23,7 +23,7 @@ limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
 log_every_n_steps: 50
-max_epochs: 1000
+max_epochs: 100
 max_steps: null
 max_time: null
 min_epochs: 1

From 2cdd81af36091f249db3031bfb497ba3346dd5a3 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 10:38:18 +0530
Subject: [PATCH 18/35] trainer to default

---
 cli/train_config/config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
index 7845b01..6b5d98e 100644
--- a/cli/train_config/config.yaml
+++ b/cli/train_config/config.yaml
@@ -3,5 +3,5 @@ defaults:
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
-  - trainer : fastrun_dev
+  - trainer : default
   - mlflow : experiment
\ No newline at end of file

From cd052e77c70acf903be7f2b34054de1a88647123 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 11:15:53 +0530
Subject: [PATCH 19/35] BS to 16

---
 cli/train_config/dataset/Vctk.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cli/train_config/dataset/Vctk.yaml b/cli/train_config/dataset/Vctk.yaml
index d40f27f..70dda64 100644
--- a/cli/train_config/dataset/Vctk.yaml
+++ b/cli/train_config/dataset/Vctk.yaml
@@ -3,7 +3,7 @@ name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 1.0
 sampling_rate: 48000
-batch_size: 32
+batch_size: 16
 
 files:
   train_clean : clean_trainset_56spk_wav

From 18759a3f843a29b6d34cf2fb5d7844c2eec1b54c Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 12:15:21 +0530
Subject: [PATCH 20/35] halve BS

---
 cli/train_config/dataset/Vctk.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cli/train_config/dataset/Vctk.yaml b/cli/train_config/dataset/Vctk.yaml
index 70dda64..07adc5d 100644
--- a/cli/train_config/dataset/Vctk.yaml
+++ b/cli/train_config/dataset/Vctk.yaml
@@ -3,7 +3,7 @@ name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 1.0
 sampling_rate: 48000
-batch_size: 16
+batch_size: 8
 
 files:
   train_clean : clean_trainset_56spk_wav

From e22cecaf2080521eef200c1fb418f874a6ca8346 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 12:27:04 +0530
Subject: [PATCH 21/35] fix model sr to dataset sr

---
 enhancer/models/model.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index c4be077..980c583 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -1,5 +1,6 @@
 from importlib import import_module
 from huggingface_hub import cached_download, hf_hub_url
+import logging
 import numpy as np
 import os
 from typing import Optional, Union, List, Text, Dict, Any
@@ -37,6 +38,9 @@ class Model(pl.LightningModule):
         super().__init__()
         assert num_channels ==1 , "Enhancer only support for mono channel models"
         self.dataset = dataset
+        if self.dataset is not None:
+            sampling_rate = self.dataset.sampling_rate
+            logging.warn("Setting model sampling rate same as dataset sampling rate")
         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
         if self.logger:
             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")

From fd526254416669ce3caf65ffaa0a7919bf664466 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 14:30:08 +0530
Subject: [PATCH 22/35] add job id to logging

---
 cli/train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/cli/train.py b/cli/train.py
index 5a056a2..16677b4 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -1,3 +1,4 @@
+import os
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig
@@ -12,7 +13,7 @@ def main(config: DictConfig):
     callbacks = []
     callbacks.append(TQDMProgressBar(refresh_rate=10))
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                          run_name=config.mlflow.run_name)
+                          run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
 
     parameters = config.hyperparameters
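Tagging each run with the SLURM job id makes it easy to match an MLflow run to its scheduler logs. One caveat with the hunk above: `os.environ.get("SLURM_JOBID")` is `None` outside SLURM, so a guarded variant (a sketch, not the project's code) may be safer for local runs:

```python
import os
from pytorch_lightning.loggers import MLFlowLogger

tags = {}
job_id = os.environ.get("SLURM_JOBID")
if job_id is not None:          # only tag when actually running under SLURM
    tags["JOB_ID"] = job_id

logger = MLFlowLogger(experiment_name="shahules/enhancer",
                      run_name="baseline", tags=tags)
```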
From a0f70010f2b53452584530bab879f04a27590850 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 16:05:02 +0530
Subject: [PATCH 23/35] rmv print

---
 enhancer/models/waveunet.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py
index 89b4bb7..a6c0d34 100644
--- a/enhancer/models/waveunet.py
+++ b/enhancer/models/waveunet.py
@@ -125,7 +125,6 @@ class WaveUnet(Model):
 
         for layer,decoder in enumerate(self.decoders):
             out = F.interpolate(out, scale_factor=2, mode="linear")
-            print(out.shape,encoder_outputs[layer].shape)
             out = self.fix_last_dim(out,encoder_outputs[layer])
             out = torch.cat([out,encoder_outputs[layer]],dim=1)
             out = decoder(out)

From c1b67c1e3a60ca052ba68e04dd5ce7182120a67c Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 16:11:45 +0530
Subject: [PATCH 24/35] use waveunet

---
 cli/train_config/config.yaml | 2 +-
 enhancer/models/waveunet.py  | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
index 6b5d98e..61551bd 100644
--- a/cli/train_config/config.yaml
+++ b/cli/train_config/config.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - model : Demucs
+  - model : WaveUnet
  - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py
index a6c0d34..b354f55 100644
--- a/enhancer/models/waveunet.py
+++ b/enhancer/models/waveunet.py
@@ -70,6 +70,8 @@ class WaveUnet(Model):
         loss: Union[str, List] = "mse",
         metric:Union[str,List] = "mse"
     ):
+        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        sampling_rate = sampling_rate if dataset is None else dataset.sampling_rate
         super().__init__(num_channels=num_channels,
                          sampling_rate=sampling_rate,lr=lr,
                          dataset=dataset,duration=duration,loss=loss,
                          metric=metric

From 4e033d2ab5cc3dd1ccdb331f942a765aa8959a5a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 17:08:49 +0530
Subject: [PATCH 25/35] log metric

---
 enhancer/models/model.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 980c583..64bf201 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -38,9 +38,6 @@ class Model(pl.LightningModule):
         super().__init__()
         assert num_channels ==1 , "Enhancer only support for mono channel models"
         self.dataset = dataset
-        if self.dataset is not None:
-            sampling_rate = self.dataset.sampling_rate
-            logging.warn("Setting model sampling rate same as dataset sampling rate")
         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
         if self.logger:
             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
@@ -86,7 +83,7 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)
 
         if self.logger:
-            self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
+            self.logger.experiment.log_metric("train_loss",loss.item(), step=self.global_step)
 
         return {"loss":loss}
 
@@ -98,7 +95,7 @@ class Model(pl.LightningModule):
         loss = self.metric(prediction, target)
 
         if self.logger:
-            self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
+            self.logger.experiment.log_metric("val_loss",loss.item(), step=self.global_step)
 
         return {"loss":loss}

From 79525df76ef60bc83c68b5d8e4b1c2d1dc0f4058 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 17:20:34 +0530
Subject: [PATCH 26/35] set sr to dataset sr

---
 enhancer/models/demucs.py   | 7 ++++++-
 enhancer/models/waveunet.py | 8 +++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py
index 0bb81d1..7c9d8ff 100644
--- a/enhancer/models/demucs.py
+++ b/enhancer/models/demucs.py
@@ -1,7 +1,8 @@
+import logging
 from typing import Optional, Union, List
 from torch import nn
 import torch.nn.functional as F
-import math 
+import math
 
 from enhancer.models.model import Model
 from enhancer.data.dataset import EnhancerDataset
@@ -112,6 +113,10 @@ class Demucs(Model):
     ):
         duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        if dataset is not None:
+            if sampling_rate!=dataset.sampling_rate:
+                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            sampling_rate = dataset.sampling_rate
         super().__init__(num_channels=num_channels,
                          sampling_rate=sampling_rate,lr=lr,
                          dataset=dataset,duration=duration,loss=loss, metric=metric)
diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py
index b354f55..f799352 100644
--- a/enhancer/models/waveunet.py
+++ b/enhancer/models/waveunet.py
@@ -1,5 +1,4 @@
-from tkinter import wantobjects
-import wave
+import logging
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -70,7 +69,10 @@ class WaveUnet(Model):
     ):
         duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
-        sampling_rate = sampling_rate if dataset is None else dataset.sampling_rate
+        if dataset is not None:
+            if sampling_rate!=dataset.sampling_rate:
+                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            sampling_rate = dataset.sampling_rate
         super().__init__(num_channels=num_channels,
                          sampling_rate=sampling_rate,lr=lr,
                          dataset=dataset,duration=duration,loss=loss,
                          metric=metric
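The same reconciliation guard now appears in both Demucs and WaveUnet; factored out, it is one small function. A sketch (not project API; it also uses `logging.warning`, the non-deprecated spelling of `logging.warn`):

```python
import logging

def resolve_sampling_rate(requested: int, dataset) -> int:
    """Prefer the dataset's sampling rate, warning when the model's differs."""
    if dataset is None:
        return requested
    if requested != dataset.sampling_rate:
        logging.warning("model sampling rate %s should match dataset sampling rate %s",
                        requested, dataset.sampling_rate)
    return dataset.sampling_rate
```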
From 0bd5d3f994b5e644cce848eabb9cb04cfb148d13 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 17:52:16 +0530
Subject: [PATCH 27/35] fix log metric

---
 enhancer/models/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 64bf201..b030b23 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -83,7 +83,7 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)
 
         if self.logger:
-            self.logger.experiment.log_metric("train_loss",loss.item(), step=self.global_step)
+            self.logger.experiment.log_metric(key="train_loss",value=loss.item(), step=self.global_step)
 
         return {"loss":loss}
 
@@ -95,7 +95,7 @@ class Model(pl.LightningModule):
         loss = self.metric(prediction, target)
 
         if self.logger:
-            self.logger.experiment.log_metric("val_loss",loss.item(), step=self.global_step)
+            self.logger.experiment.log_metric(key="val_loss",value=loss.item(), step=self.global_step)
 
         return {"loss":loss}

From 192b8ffa7bd04b3c30eb592fb1e949e06365db06 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 18:01:50 +0530
Subject: [PATCH 28/35] downsample vctk

---
 cli/train_config/dataset/Vctk.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cli/train_config/dataset/Vctk.yaml b/cli/train_config/dataset/Vctk.yaml
index 07adc5d..d1c8646 100644
--- a/cli/train_config/dataset/Vctk.yaml
+++ b/cli/train_config/dataset/Vctk.yaml
@@ -2,7 +2,7 @@ _target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 1.0
-sampling_rate: 48000
+sampling_rate: 16000
 batch_size: 8
 
 files:

From fccbd88ba290ad6ef97e54dc0d9c72bc8d1e0ce1 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 22:38:26 +0530
Subject: [PATCH 29/35] rmv sr filtering

---
 enhancer/data/dataset.py       |  6 ++----
 enhancer/data/fileprocessor.py | 26 +++++++++++---------------
 2 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 98abe8a..5749c36 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -66,15 +66,13 @@ class TaskDataset(pl.LightningDataModule):
             train_clean = os.path.join(self.root_dir,self.files.train_clean)
             train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
             fp = Fileprocessor.from_name(self.name,train_clean,
-                                        train_noisy,self.sampling_rate,
-                                        self.matching_function)
+                                        train_noisy, self.matching_function)
             self.train_data = fp.prepare_matching_dict()
 
             val_clean = os.path.join(self.root_dir,self.files.test_clean)
             val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
             fp = Fileprocessor.from_name(self.name,val_clean,
-                                        val_noisy,self.sampling_rate,
-                                        self.matching_function)
+                                        val_noisy, self.matching_function)
             val_data = fp.prepare_matching_dict()
 
             for item in val_data:
diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py
index 4df3e23..f903375 100644
--- a/enhancer/data/fileprocessor.py
+++ b/enhancer/data/fileprocessor.py
@@ -1,12 +1,13 @@
 import glob
 import os
+from re import S
 import numpy as np
 from scipy.io import wavfile
 
 class ProcessorFunctions:
 
     @staticmethod
-    def match_vtck(clean_path,noisy_path,sr):
+    def match_vtck(clean_path,noisy_path):
 
         matching_wavfiles = list()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -18,16 +19,15 @@ class ProcessorFunctions:
             sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
             sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
             if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr) and
-                (sr_noisy==sr)):
+                (sr_clean==sr_noisy)):
                 matching_wavfiles.append(
                     {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
-                    "duration":clean_file.shape[-1]/sr}
+                    "duration":clean_file.shape[-1]/sr_clean}
                 )
         return matching_wavfiles
 
     @staticmethod
-    def match_dns2020(clean_path,noisy_path,sr):
+    def match_dns2020(clean_path,noisy_path):
 
         matching_wavfiles = dict()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -38,11 +38,10 @@ class ProcessorFunctions:
             sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
             sr_noisy, noisy_file = wavfile.read(noisy_file)
             if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr) and
-                (sr_noisy==sr)):
+                (sr_clean==sr_noisy)):
                 matching_wavfiles.update(
                     {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
-                    "duration":clean_file.shape[-1]/sr}
+                    "duration":clean_file.shape[-1]/sr_clean}
                 )
         return matching_wavfiles
 
@@ -54,12 +53,10 @@ class Fileprocessor:
         self,
         clean_dir,
         noisy_dir,
-        sr = 16000,
         matching_function = None
     ):
         self.clean_dir = clean_dir
         self.noisy_dir = noisy_dir
-        self.sr = sr
         self.matching_function = matching_function
 
     @classmethod
@@ -67,23 +64,22 @@ class Fileprocessor:
         name:str,
         clean_dir,
         noisy_dir,
-        sr,
         matching_function=None
     ):
 
         if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_vtck)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
         elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_dns2020)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
         else:
-            return cls(clean_dir,noisy_dir,sr, matching_function)
+            return cls(clean_dir,noisy_dir, matching_function)
 
     def prepare_matching_dict(self):
 
         if self.matching_function is None:
             raise ValueError("Not a valid matching function")
 
-        return self.matching_function(self.clean_dir,self.noisy_dir,self.sr)
+        return self.matching_function(self.clean_dir,self.noisy_dir)
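After this refactor a matching function no longer receives a target sampling rate: it takes `(clean_path, noisy_path)` and returns records with `clean`, `noisy` and `duration` keys. A sketch of a custom function for flat clean/noisy directories, mirroring `match_vtck` above:

```python
import glob
import os

from scipy.io import wavfile

def match_flat_dirs(clean_path, noisy_path):
    matching_wavfiles = []
    for clean_file in sorted(glob.glob(os.path.join(clean_path, "*.wav"))):
        noisy_file = os.path.join(noisy_path, os.path.basename(clean_file))
        if not os.path.exists(noisy_file):
            continue
        sr_clean, clean = wavfile.read(clean_file)
        sr_noisy, noisy = wavfile.read(noisy_file)
        if clean.shape[-1] == noisy.shape[-1] and sr_clean == sr_noisy:
            matching_wavfiles.append({"clean": clean_file, "noisy": noisy_file,
                                      "duration": clean.shape[-1] / sr_clean})
    return matching_wavfiles
```

For any dataset name other than vctk/dns-2020, it would be passed through `Fileprocessor.from_name(name, clean_dir, noisy_dir, matching_function=match_flat_dirs)`.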
From 5983b094701f2858c65f9d96ff6f0517b3c2fe14 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 22:38:59 +0530
Subject: [PATCH 30/35] fix logger

---
 enhancer/models/model.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index b030b23..20c8196 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -83,7 +83,9 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)
 
         if self.logger:
-            self.logger.experiment.log_metric(key="train_loss",value=loss.item(), step=self.global_step)
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="train_loss", value=loss.item(),
+                                              step=self.global_step)
 
         return {"loss":loss}
 
@@ -95,7 +97,9 @@ class Model(pl.LightningModule):
         loss = self.metric(prediction, target)
 
         if self.logger:
-            self.logger.experiment.log_metric(key="val_loss",value=loss.item(), step=self.global_step)
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="val_loss",value=loss.item(),
+                                              step=self.global_step)
 
         return {"loss":loss}
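The run id is required because `MLFlowLogger.experiment` is a bare `MlflowClient`, whose `log_metric`, unlike the fluent `mlflow.log_metric`, must be told which run to write to. A sketch against plain mlflow 1.x:

```python
from mlflow.tracking import MlflowClient

client = MlflowClient()                      # honours MLFLOW_TRACKING_URI
run = client.create_run(experiment_id="0")   # "0" is the default experiment
client.log_metric(run.info.run_id, key="train_loss", value=0.42, step=0)
```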
From bd0bfbeea79af0d0990344fa67f3c2ed62fdc38a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 29 Sep 2022 22:39:55 +0530
Subject: [PATCH 31/35] ignore

---
 .gitignore | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index ae420f9..6eb0fe3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
-##local
+#local
+cli/train_config/dataset/Vctk_local.yaml
 .DS_Store
 outputs/
 datasets/

From c717e7c38cf22209a823d325f721f8e9e170ad0a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 10:10:35 +0530
Subject: [PATCH 32/35] log val loss

---
 enhancer/models/model.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 20c8196..8e607ed 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -86,7 +86,7 @@ class Model(pl.LightningModule):
             self.logger.experiment.log_metric(run_id=self.logger.run_id,
                                               key="train_loss", value=loss.item(),
                                               step=self.global_step)
-
+        self.log("train_loss",loss.item())
         return {"loss":loss}
 
     def validation_step(self,batch,batch_idx:int):
@@ -95,13 +95,20 @@ class Model(pl.LightningModule):
         target = batch["clean"]
         prediction = self(mixed_waveform)
 
-        loss = self.metric(prediction, target)
+        metric_val = self.metric(prediction, target)
+        loss_val = self.loss(prediction, target)
+        self.log("val_metric",metric_val.item())
+        self.log("val_loss",loss_val.item())
+
         if self.logger:
             self.logger.experiment.log_metric(run_id=self.logger.run_id,
-                                              key="val_loss",value=loss.item(),
+                                              key="val_loss",value=loss_val.item(),
+                                              step=self.global_step)
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="val_metric",value=metric_val.item(),
                                               step=self.global_step)
 
-        return {"loss":loss}
+        return {"loss":loss_val}

From 610c23a0eb18035a221def9bc4bff8debb89480b Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 10:10:51 +0530
Subject: [PATCH 33/35] change monitor

---
 cli/train.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/cli/train.py b/cli/train.py
index 16677b4..88e513a 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -4,14 +4,12 @@ from hydra.utils import instantiate
 from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
-from pytorch_lightning.callbacks import TQDMProgressBar
 
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
 
     callbacks = []
-    callbacks.append(TQDMProgressBar(refresh_rate=10))
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
                           run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
 
@@ -23,12 +21,12 @@ def main(config: DictConfig):
                     loss=parameters.get("loss"),
                     metric = parameters.get("metric"))
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor=parameters.get("loss"),verbose=False,
+        dirpath="",filename="model",monitor="valid_loss",verbose=False,
         mode="min",every_n_epochs=1
     )
     callbacks.append(checkpoint)
     early_stopping = EarlyStopping(
-        monitor=parameters.get("loss"),
+        monitor="valid_loss",
         mode="min",
         min_delta=0.0,
         patience=100,
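One thing to verify here: `monitor` must name a key that the model actually passes to `self.log(...)`. Patch 32 logs `val_loss` and `val_metric`, so monitoring `valid_loss` would appear to raise a `MisconfigurationException` once validation runs; a consistent pairing would look like this sketch:

```python
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# "val_loss" matches the key logged via self.log() in validation_step.
checkpoint = ModelCheckpoint(dirpath="", filename="model", monitor="val_loss",
                             verbose=False, mode="min", every_n_epochs=1)
early_stopping = EarlyStopping(monitor="val_loss", mode="min",
                               min_delta=0.0, patience=100)
```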
From 7b0e7e2312be6faf603969bfe6c5441ab89c05e6 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 10:15:28 +0530
Subject: [PATCH 34/35] change config

---
 cli/train_config/config.yaml                  | 2 +-
 cli/train_config/hyperparameters/default.yaml | 2 +-
 cli/train_config/trainer/default.yaml         | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
index 61551bd..6b5d98e 100644
--- a/cli/train_config/config.yaml
+++ b/cli/train_config/config.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - model : WaveUnet
+  - model : Demucs
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
diff --git a/cli/train_config/hyperparameters/default.yaml b/cli/train_config/hyperparameters/default.yaml
index 4931c7c..04b099b 100644
--- a/cli/train_config/hyperparameters/default.yaml
+++ b/cli/train_config/hyperparameters/default.yaml
@@ -1,4 +1,4 @@
 loss : mse
 metric : mae
-lr : 0.001
+lr : 0.0001
 num_epochs : 100
diff --git a/cli/train_config/trainer/default.yaml b/cli/train_config/trainer/default.yaml
index 633c6ba..560305b 100644
--- a/cli/train_config/trainer/default.yaml
+++ b/cli/train_config/trainer/default.yaml
@@ -9,7 +9,7 @@ benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: auto
+devices: -1
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True
@@ -22,7 +22,7 @@ limit_predict_batches: 1.0
 limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
-log_every_n_steps: 50
+log_every_n_steps: 10
 max_epochs: 100
 max_steps: null
 max_time: null

From 6aa502c3dcb2e53a2101fc17563b2a3580a9b5c9 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 30 Sep 2022 10:27:11 +0530
Subject: [PATCH 35/35] waveunet training

---
 cli/train_config/config.yaml          | 2 +-
 cli/train_config/trainer/default.yaml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
index 6b5d98e..61551bd 100644
--- a/cli/train_config/config.yaml
+++ b/cli/train_config/config.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - model : Demucs
+  - model : WaveUnet
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
diff --git a/cli/train_config/trainer/default.yaml b/cli/train_config/trainer/default.yaml
index 560305b..ab4e273 100644
--- a/cli/train_config/trainer/default.yaml
+++ b/cli/train_config/trainer/default.yaml
@@ -2,7 +2,7 @@ _target_: pytorch_lightning.Trainer
 accelerator: auto
 accumulate_grad_batches: 1
 amp_backend: native
-auto_lr_find: False
+auto_lr_find: True
 auto_scale_batch_size: False
 auto_select_gpus: True
 benchmark: False
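`auto_lr_find: True` has no effect on `Trainer.fit()` alone; the learning-rate finder only runs inside `Trainer.tune()`. A sketch of how cli/train.py might trigger it, assuming pytorch-lightning 1.5.x tuning semantics (the helper name is hypothetical):

```python
from hydra.utils import instantiate

def fit_with_lr_finder(config, model, logger, callbacks):
    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.tune(model)   # runs the LR finder; updates the model's lr attribute
    trainer.fit(model)
    return trainer
```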