diff --git a/.gitignore b/.gitignore index cd1b1e9..a483f45 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,8 @@ #local +cleaned_my_voice.wav +lightning_logs/ +my_voice.wav +pretrained/ *.ckpt *_local.yaml cli/train_config/dataset/Vctk_local.yaml diff --git a/mayavoz/__init__.py b/mayavoz/__init__.py index 5284146..7a5bacf 100644 --- a/mayavoz/__init__.py +++ b/mayavoz/__init__.py @@ -1 +1,2 @@ __import__("pkg_resources").declare_namespace(__name__) +from mayavoz.models import Mayamodel diff --git a/mayavoz/cli/train.py b/mayavoz/cli/train.py index c00c024..8f12ea7 100644 --- a/mayavoz/cli/train.py +++ b/mayavoz/cli/train.py @@ -19,9 +19,9 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0") @hydra.main(config_path="train_config", config_name="config") -def main(config: DictConfig): +def train(config: DictConfig): - OmegaConf.save(config, "config_log.yaml") + OmegaConf.save(config, "config.yaml") callbacks = [] logger = MLFlowLogger( @@ -96,7 +96,7 @@ def main(config: DictConfig): trainer.test(model) logger.experiment.log_artifact( - logger.run_id, f"{trainer.default_root_dir}/config_log.yaml" + logger.run_id, f"{trainer.default_root_dir}/config.yaml" ) saved_location = os.path.join( @@ -117,4 +117,4 @@ def main(config: DictConfig): if __name__ == "__main__": - main() + train() diff --git a/mayavoz/cli/train_config/dataset/DNS-2020.yaml b/mayavoz/cli/train_config/dataset/DNS-2020.yaml index 5c67be2..4be9d97 100644 --- a/mayavoz/cli/train_config/dataset/DNS-2020.yaml +++ b/mayavoz/cli/train_config/dataset/DNS-2020.yaml @@ -1,10 +1,10 @@ _target_: mayavoz.data.dataset.MayaDataset -root_dir : /Users/shahules/Myprojects/MS-SNSD name : dns-2020 +root_dir : /Users/shahules/Myprojects/MS-SNSD duration : 2.0 sampling_rate: 16000 batch_size: 32 -valid_size: 0.05 +min_valid_minutes: 15 files: train_clean : CleanSpeech_training test_clean : CleanSpeech_training diff --git a/mayavoz/data/dataset.py b/mayavoz/data/dataset.py index d47967c..810d72d 100644 --- a/mayavoz/data/dataset.py +++ b/mayavoz/data/dataset.py @@ -1,6 +1,8 @@ import math import multiprocessing import os +import sys +import warnings from pathlib import Path from typing import Optional @@ -80,6 +82,21 @@ class TaskDataset(pl.LightningDataModule): self._validation = [] if num_workers is None: num_workers = multiprocessing.cpu_count() // 2 + if num_workers is None: + num_workers = multiprocessing.cpu_count() // 2 + + if ( + num_workers > 0 + and sys.platform == "darwin" + and sys.version_info[0] >= 3 + and sys.version_info[1] >= 8 + ): + warnings.warn( + "num_workers > 0 is not supported with macOS and Python 3.8+: " + "setting num_workers = 0." + ) + num_workers = 0 + self.num_workers = num_workers if min_valid_minutes > 0.0: self.min_valid_minutes = min_valid_minutes diff --git a/mayavoz/data/fileprocessor.py b/mayavoz/data/fileprocessor.py index 5b099d4..9f6fbe5 100644 --- a/mayavoz/data/fileprocessor.py +++ b/mayavoz/data/fileprocessor.py @@ -93,7 +93,7 @@ class Fileprocessor: def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None): if matching_function is None: - if name.lower() == "vctk": + if name.lower() in ("vctk", "valentini"): return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one) elif name.lower() == "dns-2020": return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many) diff --git a/mayavoz/loss.py b/mayavoz/loss.py index 9955efd..f9c8ec4 100644 --- a/mayavoz/loss.py +++ b/mayavoz/loss.py @@ -1,4 +1,4 @@ -import logging +import warnings import numpy as np import torch @@ -134,7 +134,7 @@ class Pesq: try: pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze())) except Exception as e: - logging.warning(f"{e} error occured while calculating PESQ") + warnings.warn(f"{e} error occured while calculating PESQ") return torch.tensor(np.mean(pesq_values)) diff --git a/mayavoz/models/dccrn.py b/mayavoz/models/dccrn.py index 6b8646c..638aefe 100644 --- a/mayavoz/models/dccrn.py +++ b/mayavoz/models/dccrn.py @@ -1,4 +1,4 @@ -import logging +import warnings from typing import Any, List, Optional, Tuple, Union import torch @@ -140,11 +140,11 @@ class DCCRN(Mayamodel): metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, MayaDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else duration ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warning( + warnings.warn( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate diff --git a/mayavoz/models/demucs.py b/mayavoz/models/demucs.py index 8424f17..dbe584b 100644 --- a/mayavoz/models/demucs.py +++ b/mayavoz/models/demucs.py @@ -1,5 +1,5 @@ -import logging import math +import warnings from typing import List, Optional, Union import torch.nn.functional as F @@ -136,16 +136,17 @@ class Demucs(Mayamodel): normalize=True, lr: float = 1e-3, dataset: Optional[MayaDataset] = None, + duration: Optional[float] = None, loss: Union[str, List] = "mse", metric: Union[str, List] = "mse", floor=1e-3, ): duration = ( - dataset.duration if isinstance(dataset, MayaDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else duration ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warning( + warnings.warn( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate diff --git a/mayavoz/models/model.py b/mayavoz/models/model.py index e248b2c..5143d0b 100644 --- a/mayavoz/models/model.py +++ b/mayavoz/models/model.py @@ -24,7 +24,7 @@ CACHE_DIR = os.getenv( ) HF_TORCH_WEIGHTS = "pytorch_model.ckpt" DEFAULT_DEVICE = "cpu" -SAVE_NAME = "enhancer" +SAVE_NAME = "mayavoz" class Mayamodel(pl.LightningModule): diff --git a/mayavoz/models/waveunet.py b/mayavoz/models/waveunet.py index c9acfda..0e2ec80 100644 --- a/mayavoz/models/waveunet.py +++ b/mayavoz/models/waveunet.py @@ -1,4 +1,4 @@ -import logging +import warnings from typing import List, Optional, Union import torch @@ -103,11 +103,11 @@ class WaveUnet(Mayamodel): metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, MayaDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else duration ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warning( + warnings.warn( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate diff --git a/notebooks/Getting_started.ipynb b/notebooks/Getting_started.ipynb index c9a47dd..b25b51f 100644 --- a/notebooks/Getting_started.ipynb +++ b/notebooks/Getting_started.ipynb @@ -30,6 +30,17 @@ "! pip install -q mayavoz " ] }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e3b59ac5", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.chdir(\"/Users/shahules/Myprojects/enhancer\")" + ] + }, { "cell_type": "markdown", "id": "87ee497f", @@ -62,14 +73,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "67698871", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/anaconda3/envs/enhancer/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "\n", "from mayavoz import Mayamodel\n", - "model = Mayamodel.from_pretrained(\"mayavoz/waveunet\")\n" + "model = Mayamodel.from_pretrained(\"shahules786/mayavoz-dccrn-valentini-28spk\")\n" ] }, { @@ -82,13 +102,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "d7996c16", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1, 36414])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "file = \"myvoice.wav\"\n", - "audio = model.enhance(\"myvoice.wav\")\n", + "audio = model.enhance(\"my_voice.wav\")\n", "audio.shape" ] }, @@ -96,19 +126,84 @@ "cell_type": "markdown", "id": "8ee20a83", "metadata": {}, + "source": [ + "**Inference using numpy ndarray**\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e1a1c718", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(36414,)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "from librosa import load\n", + "my_voice,sr = load(\"my_voice.wav\",sr=16000)\n", + "my_voice.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "56b5c01b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 1, 36414)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "audio = model.enhance(my_voice,sampling_rate=sr)\n", + "audio.shape" + ] + }, + { + "cell_type": "markdown", + "id": "e0ab4d43", + "metadata": {}, "source": [ "**Inference using torch tensor**\n" ] }, { "cell_type": "code", - "execution_count": null, - "id": "e1a1c718", + "execution_count": 22, + "id": "fc6192b9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1, 36414])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "audio_tensor = torch.rand(1,1,32000) ## random audio data\n", - "audio = model.enhance(audio_tensor)\n", + "my_voice = torch.from_numpy(my_voice)\n", + "audio = model.enhance(my_voice,sampling_rate=sr)\n", "audio.shape" ] }, @@ -122,24 +217,43 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "9e0313f7", "metadata": {}, "outputs": [], "source": [ - "audio = model.enhance(\"myvoice.wav\",save_output=True)" + "audio = model.enhance(\"my_voice.wav\",save_output=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "25077720", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "from Ipython.audio import Audio\n", - "\n", - "Audio(\"myvoice_cleaned.wav\",rate=SAMPLING_RATE)" + "from IPython.display import Audio\n", + "SAMPLING_RATE = 16000\n", + "Audio(\"cleaned_my_voice.wav\",rate=SAMPLING_RATE)" ] }, { @@ -183,19 +297,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "2c8c2b12", "metadata": {}, "outputs": [], "source": [ "from mayavoz.utils import Files\n", "\n", - "name = \"dataset_name\"\n", - "root_dir = \"root_directory_of_your_dataset\"\n", - "files = Files(train_clean=\"train_cleanfiles_foldername\",\n", - " train_noisy=\"noisy_train_foldername\",\n", - " test_clean=\"clean_test_foldername\",\n", - " test_noisy=\"noisy_test_foldername\")\n", + "name = \"valentini\"\n", + "root_dir = \"/Users/shahules/Myprojects/enhancer/datasets/vctk\"\n", + "files = Files(train_clean=\"clean_testset_wav\",\n", + " train_noisy=\"clean_testset_wav\",\n", + " test_clean=\"noisy_testset_wav\",\n", + " test_noisy=\"noisy_testset_wav\")\n", "duration = 4.0 \n", "stride = None\n", "sampling_rate = 16000" @@ -207,13 +321,13 @@ "metadata": {}, "source": [ "Now there are two types of `matching_function`\n", - "- `one_to_one` : In this one clean file will only have one corresponding noisy file. For example VCTK datasets\n", + "- `one_to_one` : In this one clean file will only have one corresponding noisy file. For example Valentini datasets\n", "- `one_to_many` : In this one clean file will only have one corresponding noisy file. For example DNS dataset." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "4b0fdc62", "metadata": {}, "outputs": [], @@ -223,25 +337,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "ff0cfe60", "metadata": {}, "outputs": [], "source": [ - "from mayavoz.dataset import MayaDataset\n", + "from mayavoz.data import MayaDataset\n", "dataset = MayaDataset(\n", " name=name,\n", " root_dir=root_dir,\n", " files=files,\n", " duration=duration,\n", " stride=stride,\n", - " sampling_rate=sampling_rate\n", + " sampling_rate=sampling_rate,\n", + " min_valid_minutes = 5.0,\n", " )\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "acfdc655", "metadata": {}, "outputs": [], @@ -252,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 12, "id": "4fabe46d", "metadata": {}, "outputs": [], @@ -262,13 +377,91 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "20d98ed0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Selected fp257 for valid\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "----------------------------------------\n", + "0 | _loss | LossWrapper | 0 \n", + "1 | encoder | ModuleList | 4.7 M \n", + "2 | decoder | ModuleList | 4.7 M \n", + "3 | de_lstm | DemucsLSTM | 24.8 M\n", + "----------------------------------------\n", + "34.2 M Trainable params\n", + "0 Non-trainable params\n", + "34.2 M Total params\n", + "136.866 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total train duration 27.4 minutes\n", + "Total validation duration 29.733333333333334 minutes\n", + "Total test duration 57.2 minutes\n", + "Epoch 0: 48%|▍| 13/27 [15:18<16:29, 70.66s/it, loss=0.0265, v_num=2, train_loss\n", + "Validation: 0it [00:00, ?it/s]\u001b[A\n", + "Validation: 0%| | 0/14 [00:00