From 90fbfbce73752c392dbb2b89ec2aea6e2ffdb55c Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 21:39:18 +0530 Subject: [PATCH 1/9] examples --- notebooks/Getting_started.ipynb | 324 +++++++++++++++++++++++++------- 1 file changed, 260 insertions(+), 64 deletions(-) diff --git a/notebooks/Getting_started.ipynb b/notebooks/Getting_started.ipynb index c9a47dd..b25b51f 100644 --- a/notebooks/Getting_started.ipynb +++ b/notebooks/Getting_started.ipynb @@ -30,6 +30,17 @@ "! pip install -q mayavoz " ] }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e3b59ac5", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.chdir(\"/Users/shahules/Myprojects/enhancer\")" + ] + }, { "cell_type": "markdown", "id": "87ee497f", @@ -62,14 +73,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "67698871", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/anaconda3/envs/enhancer/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "\n", "from mayavoz import Mayamodel\n", - "model = Mayamodel.from_pretrained(\"mayavoz/waveunet\")\n" + "model = Mayamodel.from_pretrained(\"shahules786/mayavoz-dccrn-valentini-28spk\")\n" ] }, { @@ -82,13 +102,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "d7996c16", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1, 36414])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "file = \"myvoice.wav\"\n", - "audio = model.enhance(\"myvoice.wav\")\n", + "audio = model.enhance(\"my_voice.wav\")\n", "audio.shape" ] }, @@ -96,19 +126,84 @@ "cell_type": "markdown", "id": "8ee20a83", "metadata": {}, + "source": [ + "**Inference using numpy ndarray**\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e1a1c718", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(36414,)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "from librosa import load\n", + "my_voice,sr = load(\"my_voice.wav\",sr=16000)\n", + "my_voice.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "56b5c01b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 1, 36414)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "audio = model.enhance(my_voice,sampling_rate=sr)\n", + "audio.shape" + ] + }, + { + "cell_type": "markdown", + "id": "e0ab4d43", + "metadata": {}, "source": [ "**Inference using torch tensor**\n" ] }, { "cell_type": "code", - "execution_count": null, - "id": "e1a1c718", + "execution_count": 22, + "id": "fc6192b9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1, 36414])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "audio_tensor = torch.rand(1,1,32000) ## random audio data\n", - "audio = model.enhance(audio_tensor)\n", + "my_voice = torch.from_numpy(my_voice)\n", + "audio = model.enhance(my_voice,sampling_rate=sr)\n", "audio.shape" ] }, @@ -122,24 +217,43 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "9e0313f7", "metadata": {}, "outputs": [], "source": [ - "audio = model.enhance(\"myvoice.wav\",save_output=True)" + "audio = model.enhance(\"my_voice.wav\",save_output=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "25077720", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "from Ipython.audio import Audio\n", - "\n", - "Audio(\"myvoice_cleaned.wav\",rate=SAMPLING_RATE)" + "from IPython.display import Audio\n", + "SAMPLING_RATE = 16000\n", + "Audio(\"cleaned_my_voice.wav\",rate=SAMPLING_RATE)" ] }, { @@ -183,19 +297,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "2c8c2b12", "metadata": {}, "outputs": [], "source": [ "from mayavoz.utils import Files\n", "\n", - "name = \"dataset_name\"\n", - "root_dir = \"root_directory_of_your_dataset\"\n", - "files = Files(train_clean=\"train_cleanfiles_foldername\",\n", - " train_noisy=\"noisy_train_foldername\",\n", - " test_clean=\"clean_test_foldername\",\n", - " test_noisy=\"noisy_test_foldername\")\n", + "name = \"valentini\"\n", + "root_dir = \"/Users/shahules/Myprojects/enhancer/datasets/vctk\"\n", + "files = Files(train_clean=\"clean_testset_wav\",\n", + " train_noisy=\"clean_testset_wav\",\n", + " test_clean=\"noisy_testset_wav\",\n", + " test_noisy=\"noisy_testset_wav\")\n", "duration = 4.0 \n", "stride = None\n", "sampling_rate = 16000" @@ -207,13 +321,13 @@ "metadata": {}, "source": [ "Now there are two types of `matching_function`\n", - "- `one_to_one` : In this one clean file will only have one corresponding noisy file. For example VCTK datasets\n", + "- `one_to_one` : In this one clean file will only have one corresponding noisy file. For example Valentini datasets\n", "- `one_to_many` : In this one clean file will only have one corresponding noisy file. For example DNS dataset." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "4b0fdc62", "metadata": {}, "outputs": [], @@ -223,25 +337,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "ff0cfe60", "metadata": {}, "outputs": [], "source": [ - "from mayavoz.dataset import MayaDataset\n", + "from mayavoz.data import MayaDataset\n", "dataset = MayaDataset(\n", " name=name,\n", " root_dir=root_dir,\n", " files=files,\n", " duration=duration,\n", " stride=stride,\n", - " sampling_rate=sampling_rate\n", + " sampling_rate=sampling_rate,\n", + " min_valid_minutes = 5.0,\n", " )\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "acfdc655", "metadata": {}, "outputs": [], @@ -252,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 12, "id": "4fabe46d", "metadata": {}, "outputs": [], @@ -262,13 +377,91 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "20d98ed0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Selected fp257 for valid\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "----------------------------------------\n", + "0 | _loss | LossWrapper | 0 \n", + "1 | encoder | ModuleList | 4.7 M \n", + "2 | decoder | ModuleList | 4.7 M \n", + "3 | de_lstm | DemucsLSTM | 24.8 M\n", + "----------------------------------------\n", + "34.2 M Trainable params\n", + "0 Non-trainable params\n", + "34.2 M Total params\n", + "136.866 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total train duration 27.4 minutes\n", + "Total validation duration 29.733333333333334 minutes\n", + "Total test duration 57.2 minutes\n", + "Epoch 0: 48%|▍| 13/27 [15:18<16:29, 70.66s/it, loss=0.0265, v_num=2, train_loss\n", + "Validation: 0it [00:00, ?it/s]\u001b[A\n", + "Validation: 0%| | 0/14 [00:00 Date: Tue, 15 Nov 2022 21:39:35 +0530 Subject: [PATCH 2/9] gitignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index cd1b1e9..a483f45 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,8 @@ #local +cleaned_my_voice.wav +lightning_logs/ +my_voice.wav +pretrained/ *.ckpt *_local.yaml cli/train_config/dataset/Vctk_local.yaml From 003bab91f95ab0b40df32ba25d080ee9d753e74a Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 21:39:47 +0530 Subject: [PATCH 3/9] tests --- tests/models/demucs_test.py | 4 +++- tests/models/test_dccrn.py | 4 +++- tests/models/test_waveunet.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index e1203b7..0472cc2 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -15,7 +15,9 @@ def vctk_dataset(): test_clean="clean_testset_wav", test_noisy="noisy_testset_wav", ) - dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) + dataset = MayaDataset( + name="vctk", root_dir=root_dir, files=files, sampling_rate=16000 + ) return dataset diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py index bc2a039..b309cb7 100644 --- a/tests/models/test_dccrn.py +++ b/tests/models/test_dccrn.py @@ -15,7 +15,9 @@ def vctk_dataset(): test_clean="clean_testset_wav", test_noisy="noisy_testset_wav", ) - dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) + dataset = MayaDataset( + name="vctk", root_dir=root_dir, files=files, sampling_rate=16000 + ) return dataset diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index bc250d1..ca0af0a 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -15,7 +15,9 @@ def vctk_dataset(): test_clean="clean_testset_wav", test_noisy="noisy_testset_wav", ) - dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) + dataset = MayaDataset( + name="vctk", root_dir=root_dir, files=files, sampling_rate=16000 + ) return dataset From 2bfca78caac08aff50a632224b95921939178a66 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 21:42:02 +0530 Subject: [PATCH 4/9] fix duration --- mayavoz/models/dccrn.py | 6 +++--- mayavoz/models/demucs.py | 7 ++++--- mayavoz/models/model.py | 2 +- mayavoz/models/waveunet.py | 6 +++--- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mayavoz/models/dccrn.py b/mayavoz/models/dccrn.py index 6b8646c..638aefe 100644 --- a/mayavoz/models/dccrn.py +++ b/mayavoz/models/dccrn.py @@ -1,4 +1,4 @@ -import logging +import warnings from typing import Any, List, Optional, Tuple, Union import torch @@ -140,11 +140,11 @@ class DCCRN(Mayamodel): metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, MayaDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else duration ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warning( + warnings.warn( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate diff --git a/mayavoz/models/demucs.py b/mayavoz/models/demucs.py index 8424f17..dbe584b 100644 --- a/mayavoz/models/demucs.py +++ b/mayavoz/models/demucs.py @@ -1,5 +1,5 @@ -import logging import math +import warnings from typing import List, Optional, Union import torch.nn.functional as F @@ -136,16 +136,17 @@ class Demucs(Mayamodel): normalize=True, lr: float = 1e-3, dataset: Optional[MayaDataset] = None, + duration: Optional[float] = None, loss: Union[str, List] = "mse", metric: Union[str, List] = "mse", floor=1e-3, ): duration = ( - dataset.duration if isinstance(dataset, MayaDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else duration ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warning( + warnings.warn( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate diff --git a/mayavoz/models/model.py b/mayavoz/models/model.py index e248b2c..5143d0b 100644 --- a/mayavoz/models/model.py +++ b/mayavoz/models/model.py @@ -24,7 +24,7 @@ CACHE_DIR = os.getenv( ) HF_TORCH_WEIGHTS = "pytorch_model.ckpt" DEFAULT_DEVICE = "cpu" -SAVE_NAME = "enhancer" +SAVE_NAME = "mayavoz" class Mayamodel(pl.LightningModule): diff --git a/mayavoz/models/waveunet.py b/mayavoz/models/waveunet.py index c9acfda..0e2ec80 100644 --- a/mayavoz/models/waveunet.py +++ b/mayavoz/models/waveunet.py @@ -1,4 +1,4 @@ -import logging +import warnings from typing import List, Optional, Union import torch @@ -103,11 +103,11 @@ class WaveUnet(Mayamodel): metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, MayaDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else duration ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warning( + warnings.warn( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate From b99ef95719eb2d5c5e483c753f518c1377359bf5 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 21:45:07 +0530 Subject: [PATCH 5/9] train config --- mayavoz/cli/train_config/dataset/DNS-2020.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mayavoz/cli/train_config/dataset/DNS-2020.yaml b/mayavoz/cli/train_config/dataset/DNS-2020.yaml index 5c67be2..4be9d97 100644 --- a/mayavoz/cli/train_config/dataset/DNS-2020.yaml +++ b/mayavoz/cli/train_config/dataset/DNS-2020.yaml @@ -1,10 +1,10 @@ _target_: mayavoz.data.dataset.MayaDataset -root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 +root_dir : /Users/shahules/Myprojects/MS-SNSD duration : 2.0 sampling_rate: 16000 batch_size: 32 -valid_size: 0.05 +min_valid_minutes: 15 files: train_clean : CleanSpeech_training test_clean : CleanSpeech_training From 191c6a7499d2f3806553f3c5bba23557b14052b0 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 21:50:24 +0530 Subject: [PATCH 6/9] add warnings --- mayavoz/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mayavoz/loss.py b/mayavoz/loss.py index 9955efd..f9c8ec4 100644 --- a/mayavoz/loss.py +++ b/mayavoz/loss.py @@ -1,4 +1,4 @@ -import logging +import warnings import numpy as np import torch @@ -134,7 +134,7 @@ class Pesq: try: pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze())) except Exception as e: - logging.warning(f"{e} error occured while calculating PESQ") + warnings.warn(f"{e} error occured while calculating PESQ") return torch.tensor(np.mean(pesq_values)) From 434b44ddc9252c5b4036c24eb3baa0654c6a10cf Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 21:51:06 +0530 Subject: [PATCH 7/9] minor fixes --- mayavoz/data/dataset.py | 17 +++++++++++++++++ mayavoz/data/fileprocessor.py | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mayavoz/data/dataset.py b/mayavoz/data/dataset.py index d47967c..810d72d 100644 --- a/mayavoz/data/dataset.py +++ b/mayavoz/data/dataset.py @@ -1,6 +1,8 @@ import math import multiprocessing import os +import sys +import warnings from pathlib import Path from typing import Optional @@ -80,6 +82,21 @@ class TaskDataset(pl.LightningDataModule): self._validation = [] if num_workers is None: num_workers = multiprocessing.cpu_count() // 2 + if num_workers is None: + num_workers = multiprocessing.cpu_count() // 2 + + if ( + num_workers > 0 + and sys.platform == "darwin" + and sys.version_info[0] >= 3 + and sys.version_info[1] >= 8 + ): + warnings.warn( + "num_workers > 0 is not supported with macOS and Python 3.8+: " + "setting num_workers = 0." + ) + num_workers = 0 + self.num_workers = num_workers if min_valid_minutes > 0.0: self.min_valid_minutes = min_valid_minutes diff --git a/mayavoz/data/fileprocessor.py b/mayavoz/data/fileprocessor.py index 5b099d4..9f6fbe5 100644 --- a/mayavoz/data/fileprocessor.py +++ b/mayavoz/data/fileprocessor.py @@ -93,7 +93,7 @@ class Fileprocessor: def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None): if matching_function is None: - if name.lower() == "vctk": + if name.lower() in ("vctk", "valentini"): return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one) elif name.lower() == "dns-2020": return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many) From 7afe928ee151bd0123b86210105b3f3de0ab7f5c Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 21:51:28 +0530 Subject: [PATCH 8/9] relative imports --- mayavoz/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mayavoz/__init__.py b/mayavoz/__init__.py index 5284146..7a5bacf 100644 --- a/mayavoz/__init__.py +++ b/mayavoz/__init__.py @@ -1 +1,2 @@ __import__("pkg_resources").declare_namespace(__name__) +from mayavoz.models import Mayamodel From 9ee809a047ce7eb3b42c25da5a4adfc3c79b3858 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 15 Nov 2022 21:51:45 +0530 Subject: [PATCH 9/9] rename to train --- mayavoz/cli/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mayavoz/cli/train.py b/mayavoz/cli/train.py index c00c024..8f12ea7 100644 --- a/mayavoz/cli/train.py +++ b/mayavoz/cli/train.py @@ -19,9 +19,9 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0") @hydra.main(config_path="train_config", config_name="config") -def main(config: DictConfig): +def train(config: DictConfig): - OmegaConf.save(config, "config_log.yaml") + OmegaConf.save(config, "config.yaml") callbacks = [] logger = MLFlowLogger( @@ -96,7 +96,7 @@ def main(config: DictConfig): trainer.test(model) logger.experiment.log_artifact( - logger.run_id, f"{trainer.default_root_dir}/config_log.yaml" + logger.run_id, f"{trainer.default_root_dir}/config.yaml" ) saved_location = os.path.join( @@ -117,4 +117,4 @@ def main(config: DictConfig): if __name__ == "__main__": - main() + train()