diff --git a/.flake8 b/.flake8 index 861f69a..abbbc73 100644 --- a/.flake8 +++ b/.flake8 @@ -6,4 +6,4 @@ ignore = E203, E266, E501, W503 max-line-length = 80 max-complexity = 18 select = B,C,E,F,W,T4,B9 -exclude = tools/kaldi_decoder \ No newline at end of file +exclude = tools/kaldi_decoder diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..807429c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,43 @@ + +repos: + # # Clean Notebooks + # - repo: https://github.com/kynan/nbstripout + # rev: master + # hooks: + # - id: nbstripout + # Format Code + - repo: https://github.com/ambv/black + rev: 22.8.0 + hooks: + - id: black + + # Sort imports + - repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort + args: ["--profile", "black"] + + - repo: https://gitlab.com/pycqa/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + args: ['--ignore=E203,E501,F811,E712,W503'] + + # Formatting, Whitespace, etc + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: mixed-line-ending + args: ['--fix=no'] diff --git a/README.md b/README.md index e462afa..743a823 100644 --- a/README.md +++ b/README.md @@ -1 +1 @@ -# enhancer \ No newline at end of file +# enhancer diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index 814fa0f..cb3c7c1 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -1,11 +1,12 @@ import os from types import MethodType + import hydra from hydra.utils import instantiate from omegaconf import DictConfig -from torch.optim.lr_scheduler import ReduceLROnPlateau -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import MLFlowLogger +from torch.optim.lr_scheduler import ReduceLROnPlateau os.environ["HYDRA_FULL_ERROR"] = "1" JOB_ID = os.environ.get("SLURM_JOBID", "0") diff --git a/enhancer/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml index 61551bd..c0b2cf6 100644 --- a/enhancer/cli/train_config/config.yaml +++ b/enhancer/cli/train_config/config.yaml @@ -4,4 +4,4 @@ defaults: - optimizer : Adam - hyperparameters : default - trainer : default - - mlflow : experiment \ No newline at end of file + - mlflow : experiment diff --git a/enhancer/cli/train_config/dataset/DNS-2020.yaml b/enhancer/cli/train_config/dataset/DNS-2020.yaml index f59cb2b..3bd0e67 100644 --- a/enhancer/cli/train_config/dataset/DNS-2020.yaml +++ b/enhancer/cli/train_config/dataset/DNS-2020.yaml @@ -10,4 +10,3 @@ files: test_clean : clean_test_wav train_noisy : clean_test_wav test_noisy : clean_test_wav - diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml index 129d9a8..5c19320 100644 --- a/enhancer/cli/train_config/dataset/Vctk.yaml +++ b/enhancer/cli/train_config/dataset/Vctk.yaml @@ -10,6 +10,3 @@ files: test_clean : clean_testset_wav train_noisy : noisy_trainset_28spk_wav test_noisy : noisy_testset_wav - - - diff --git a/enhancer/cli/train_config/dataset/Vctk_local.yaml b/enhancer/cli/train_config/dataset/Vctk_local.yaml index b792b71..ba44597 100644 --- a/enhancer/cli/train_config/dataset/Vctk_local.yaml +++ b/enhancer/cli/train_config/dataset/Vctk_local.yaml @@ -10,4 +10,4 @@ files: train_clean : clean_testset_wav test_clean : clean_testset_wav train_noisy : noisy_testset_wav - test_noisy : noisy_testset_wav \ No newline at end of file + test_noisy : noisy_testset_wav diff --git a/enhancer/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml index 82ac3c2..7e4cda3 100644 --- a/enhancer/cli/train_config/hyperparameters/default.yaml +++ b/enhancer/cli/train_config/hyperparameters/default.yaml @@ -5,4 +5,3 @@ ReduceLr_patience : 5 ReduceLr_factor : 0.1 min_lr : 0.000001 EarlyStopping_factor : 10 - diff --git a/enhancer/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml index 2995c60..e8893f6 100644 --- a/enhancer/cli/train_config/mlflow/experiment.yaml +++ b/enhancer/cli/train_config/mlflow/experiment.yaml @@ -1,2 +1,2 @@ experiment_name : shahules/enhancer -run_name : baseline \ No newline at end of file +run_name : baseline diff --git a/enhancer/cli/train_config/model/Demucs.yaml b/enhancer/cli/train_config/model/Demucs.yaml index 1006e71..d91b5ff 100644 --- a/enhancer/cli/train_config/model/Demucs.yaml +++ b/enhancer/cli/train_config/model/Demucs.yaml @@ -14,5 +14,3 @@ encoder_decoder: lstm: bidirectional: False num_layers: 2 - - diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index d194167..95c73a1 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -1,16 +1,17 @@ -import multiprocessing import math +import multiprocessing import os -import pytorch_lightning as pl -from torch.utils.data import IterableDataset, DataLoader, Dataset -import torch.nn.functional as F from typing import Optional +import pytorch_lightning as pl +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, IterableDataset + from enhancer.data.fileprocessor import Fileprocessor -from enhancer.utils.random import create_unique_rng -from enhancer.utils.io import Audio from enhancer.utils import check_files from enhancer.utils.config import Files +from enhancer.utils.io import Audio +from enhancer.utils.random import create_unique_rng class TrainDataset(IterableDataset): diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py index 5cc9b31..66d4d75 100644 --- a/enhancer/data/fileprocessor.py +++ b/enhancer/data/fileprocessor.py @@ -1,5 +1,6 @@ import glob import os + import numpy as np from scipy.io import wavfile diff --git a/enhancer/inference.py b/enhancer/inference.py index 1abd8bb..ae399f1 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -1,11 +1,12 @@ -import numpy as np -from scipy.signal import get_window -from scipy.io import wavfile +from pathlib import Path from typing import Optional, Union + +import numpy as np import torch import torch.nn.functional as F -from pathlib import Path from librosa import load as load_audio +from scipy.io import wavfile +from scipy.signal import get_window from enhancer.utils import Audio diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py index 368a9d7..2d97568 100644 --- a/enhancer/models/__init__.py +++ b/enhancer/models/__init__.py @@ -1,3 +1,3 @@ from enhancer.models.demucs import Demucs -from enhancer.models.waveunet import WaveUnet from enhancer.models.model import Model +from enhancer.models.waveunet import WaveUnet diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 76a0bf7..65f119d 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -1,11 +1,12 @@ import logging -from typing import Optional, Union, List -from torch import nn -import torch.nn.functional as F import math +from typing import List, Optional, Union + +import torch.nn.functional as F +from torch import nn -from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model from enhancer.utils.io import Audio as audio from enhancer.utils.utils import merge_dict diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 56f24db..39dbe80 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,20 +1,20 @@ -from importlib import import_module -from huggingface_hub import cached_download, hf_hub_url -import numpy as np import os -from typing import Optional, Union, List, Text, Dict, Any -from torch.optim import Adam -import torch -import pytorch_lightning as pl -from pytorch_lightning.utilities.cloud_io import load as pl_load -from urllib.parse import urlparse +from importlib import import_module from pathlib import Path +from typing import Any, Dict, List, Optional, Text, Union +from urllib.parse import urlparse +import numpy as np +import pytorch_lightning as pl +import torch +from huggingface_hub import cached_download, hf_hub_url +from pytorch_lightning.utilities.cloud_io import load as pl_load +from torch.optim import Adam from enhancer import __version__ from enhancer.data.dataset import EnhancerDataset -from enhancer.loss import Avergeloss from enhancer.inference import Inference +from enhancer.loss import Avergeloss CACHE_DIR = "" HF_TORCH_WEIGHTS = "" @@ -293,7 +293,7 @@ class Model(pl.LightningModule): with torch.no_grad(): for batch_id in range(0, batch.shape[0], batch_size): - batch_data = batch[batch_id: batch_id + batch_size, :, :].to( + batch_data = batch[batch_id : batch_id + batch_size, :, :].to( self.device ) prediction = self(batch_data) diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index 4d5cc0a..ebb4b1f 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -1,11 +1,12 @@ import logging +from typing import List, Optional, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Union, List -from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model class WavenetDecoder(nn.Module): diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py index c9f5438..de0db9f 100644 --- a/enhancer/utils/__init__.py +++ b/enhancer/utils/__init__.py @@ -1,3 +1,3 @@ -from enhancer.utils.utils import check_files -from enhancer.utils.io import Audio from enhancer.utils.config import Files +from enhancer.utils.io import Audio +from enhancer.utils.utils import check_files diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py index 3703285..9e9ce32 100644 --- a/enhancer/utils/io.py +++ b/enhancer/utils/io.py @@ -1,7 +1,8 @@ import os -import librosa from pathlib import Path from typing import Optional, Union + +import librosa import numpy as np import torch import torchaudio diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py index 51e09c0..dd9395a 100644 --- a/enhancer/utils/random.py +++ b/enhancer/utils/random.py @@ -1,5 +1,6 @@ import os import random + import torch diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py index ebb41b4..ad45139 100644 --- a/enhancer/utils/utils.py +++ b/enhancer/utils/utils.py @@ -1,5 +1,6 @@ import os from typing import Optional + from enhancer.utils.config import Files diff --git a/environment.yml b/environment.yml index 4f211bf..8da22e1 100644 --- a/environment.yml +++ b/environment.yml @@ -5,4 +5,4 @@ dependencies: - python=3.8 - pip: - -r requirements.txt - - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html \ No newline at end of file + - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html diff --git a/hpc_entrypoint.sh b/hpc_entrypoint.sh index 7372eb9..6d6a3a0 100644 --- a/hpc_entrypoint.sh +++ b/hpc_entrypoint.sh @@ -33,7 +33,7 @@ mkdir temp pwd #python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TRAIN --output ./data/train -#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test +#python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test echo "Start Training..." python cli/train.py diff --git a/pyproject.toml b/pyproject.toml index 8f12f30..b3e5d7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,4 +12,4 @@ exclude = ''' | \.venv )/ ) -''' \ No newline at end of file +''' diff --git a/requirements.txt b/requirements.txt index b2f3638..afa3641 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,16 @@ -joblib==1.2.0 -librosa==0.9.2 -numpy==1.23.3 -hydra-core==1.2.0 -scikit-learn==1.1.2 -scipy==1.9.1 -torch==1.12.1 -tqdm==4.64.1 -mlflow==1.29.0 -protobuf==3.19.6 -boto3==1.24.86 -torchaudio==0.12.1 -huggingface-hu==0.10.0 -pytorch-lightning==1.7.7 -flake8==5.0.4 -black==22.8.0 \ No newline at end of file +black>=22.8.0 +boto3>=1.24.86 +flake8>=5.0.4 +huggingface-hu>=0.10.0 +hydra-core>=1.2.0 +joblib>=1.2.0 +librosa>=0.9.2 +mlflow>=1.29.0 +numpy>=1.23.3 +protobuf>=3.19.6 +pytorch-lightning>=1.7.7 +scikit-learn>=1.1.2 +scipy>=1.9.1 +torch>=1.12.1 +torchaudio>=0.12.1 +tqdm>=4.64.1 diff --git a/setup.sh b/setup.sh index adad46c..43adc89 100644 --- a/setup.sh +++ b/setup.sh @@ -10,4 +10,4 @@ conda env create -f environment.yml || conda env update -f environment.yml source activate enhancer echo "copying files" -# cp /scratch/$USER/TIMIT/.* /deep-transcriber \ No newline at end of file +# cp /scratch/$USER/TIMIT/.* /deep-transcriber diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py index a4fdc62..4d14871 100644 --- a/tests/loss_function_test.py +++ b/tests/loss_function_test.py @@ -1,6 +1,5 @@ -from asyncio import base_tasks -import torch import pytest +import torch from enhancer.loss import mean_absolute_error, mean_squared_error diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index 6660888..f5a0ec4 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -1,10 +1,9 @@ import pytest import torch -from enhancer import data -from enhancer.utils.config import Files -from enhancer.models import Demucs from enhancer.data.dataset import EnhancerDataset +from enhancer.models import Demucs +from enhancer.utils.config import Files @pytest.fixture @@ -41,4 +40,4 @@ def test_forward(batch_size, samples): ) def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = Demucs(num_channels=channels, dataset=dataset, loss=loss) + _ = Demucs(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py index c83966b..9c4dd96 100644 --- a/tests/models/test_waveunet.py +++ b/tests/models/test_waveunet.py @@ -1,10 +1,9 @@ import pytest import torch -from enhancer import data -from enhancer.utils.config import Files -from enhancer.models import WaveUnet from enhancer.data.dataset import EnhancerDataset +from enhancer.models import WaveUnet +from enhancer.utils.config import Files @pytest.fixture @@ -41,4 +40,4 @@ def test_forward(batch_size, samples): ) def test_demucs_init(dataset, channels, loss): with torch.no_grad(): - model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss) + _ = WaveUnet(num_channels=channels, dataset=dataset, loss=loss) diff --git a/tests/utils_test.py b/tests/utils_test.py index 93a9094..65c723d 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -1,11 +1,9 @@ -from logging import root +import numpy as np import pytest import torch -import numpy as np -from enhancer.utils.io import Audio -from enhancer.utils.config import Files from enhancer.data.fileprocessor import Fileprocessor +from enhancer.utils.io import Audio def test_io_channel(): @@ -47,6 +45,6 @@ def test_fileprocessor_names(dataset_name): def test_fileprocessor_invaliname(): with pytest.raises(ValueError): - fp = Fileprocessor.from_name( + _ = Fileprocessor.from_name( "undefined", "clean_dir", "noisy_dir", 16000 ).prepare_matching_dict()