diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6b3c716..4c64745 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -37,13 +37,15 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + sudo apt-get install libsndfile1 pip install -r requirements.txt - pip install black flake8 pytest-cov + pip install black pytest-cov + - name: Install enhancer + run: | + pip install -e .[dev,testing] - name: Run black run: - black --check . - - name: Run flake8 - run: flake8 + black --check . --exclude enhancer/version.py - name: Test with pytest run: pytest tests --cov=enhancer/ diff --git a/.gitignore b/.gitignore index 6eb0fe3..9cd222c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ #local +*.ckpt cli/train_config/dataset/Vctk_local.yaml .DS_Store outputs/ diff --git a/enhancer/__init__.py b/enhancer/__init__.py index f102a9c..5284146 100644 --- a/enhancer/__init__.py +++ b/enhancer/__init__.py @@ -1 +1 @@ -__version__ = "0.0.1" +__import__("pkg_resources").declare_namespace(__name__) diff --git a/enhancer/inference.py b/enhancer/inference.py index fd3b518..d9282fd 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -139,7 +139,9 @@ class Inference: if filename.is_file(): raise FileExistsError(f"file {filename} already exists") else: - wavfile.write(filename, rate=sr, data=waveform.detach().cpu()) + wavfile.write( + filename, rate=sr, data=waveform.detach().cpu().numpy() + ) @staticmethod def prepare_output( diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 6e6b4e1..7ff15e4 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -11,10 +11,10 @@ from huggingface_hub import cached_download, hf_hub_url from pytorch_lightning.utilities.cloud_io import load as pl_load from torch.optim import Adam -from enhancer import __version__ from enhancer.data.dataset import EnhancerDataset from enhancer.inference import Inference from enhancer.loss import Avergeloss +from enhancer.version import __version__ CACHE_DIR = "" HF_TORCH_WEIGHTS = "" @@ -120,7 +120,11 @@ class Model(pl.LightningModule): loss = self.loss(prediction, target) - if self.logger: + if ( + (self.logger) + and (self.global_step > 50) + and (self.global_step % 50 == 0) + ): self.logger.experiment.log_metric( run_id=self.logger.run_id, key="train_loss", @@ -141,7 +145,11 @@ class Model(pl.LightningModule): self.log("val_metric", metric_val.item()) self.log("val_loss", loss_val.item()) - if self.logger: + if ( + (self.logger) + and (self.global_step > 50) + and (self.global_step % 50 == 0) + ): self.logger.experiment.log_metric( run_id=self.logger.run_id, key="val_loss", @@ -209,8 +217,7 @@ class Model(pl.LightningModule): to True or to a string containing your hugginface.co authentication token that can be obtained by running `huggingface-cli login` cache_dir: Path or str, optional - Path to model cache directory. Defaults to content of PYANNOTE_CACHE - environment variable, or "~/.cache/torch/pyannote" when unset. + Path to model cache directory kwargs: optional Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values. @@ -290,10 +297,9 @@ class Model(pl.LightningModule): ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" batch_predictions = [] self.eval().to(self.device) - with torch.no_grad(): for batch_id in range(0, batch.shape[0], batch_size): - batch_data = batch[batch_id : batch_id + batch_size, :, :].to( + batch_data = batch[batch_id : (batch_id + batch_size), :, :].to( self.device ) prediction = self(batch_data) diff --git a/enhancer/version.py b/enhancer/version.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/enhancer/version.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/requirements.txt b/requirements.txt index bb13983..3762fd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,9 +6,11 @@ librosa>=0.9.2 mlflow>=1.29.0 numpy>=1.23.3 protobuf>=3.19.6 +pytest-lazy-fixture>=0.6.3 pytorch-lightning>=1.7.7 scikit-learn>=1.1.2 scipy>=1.9.1 +soundfile>=0.11.0 torch>=1.12.1 torchaudio>=0.12.1 tqdm>=4.64.1 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..309ac9a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,100 @@ +# This file is used to configure your project. +# Read more about the various options under: +# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files + +[metadata] +name = enhancer +description = Deep learning for speech enhacement +author = Shahul Ess +author-email = shahules786@gmail.com +license = mit +long-description = file: README.md +long-description-content-type = text/markdown; charset=UTF-8; variant=GFM +# Change if running only on Windows, Mac or Linux (comma-separated) +platforms = Linux, Mac +# Add here all kinds of additional classifiers as defined under +# https://pypi.python.org/pypi?%3Aaction=list_classifiers +classifiers = + Development Status :: 4 - Beta + Programming Language :: Python + +[options] +zip_safe = False +packages = find: +include_package_data = True +# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD! +setup_requires = setuptools +# Add here dependencies of your project (semicolon/line-separated), e.g. +# install_requires = numpy; scipy +# Require a specific Python version, e.g. Python 2.7 or >= 3.4 +python_requires = >=3.8 + +[options.packages.find] +where = . +exclude = + tests + +[options.extras_require] +# Add here additional requirements for extra features, to install with: +# `pip install fastaudio[PDF]` like: +# PDF = ReportLab; RXP +# Add here test requirements (semicolon/line-separated) +testing = + pytest>=7.1.3 + pytest-cov>=4.0.0 +dev = + pre-commit>=2.20.0 + black>=22.8.0 + flake8>=5.0.4 +cli = + hydra-core >=1.1,<=1.2 + + +[options.entry_points] + +console_scripts = + enhancer-train=enhancer.cli.train:train + +[test] +# py.test options when running `python setup.py test` +# addopts = --verbose +extras = True + +[tool:pytest] +# Options for py.test: +# Specify command line options as you would do when invoking py.test directly. +# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml +# in order to write a coverage file that can be read by Jenkins. +addopts = + --cov enhancer --cov-report term-missing + --verbose +norecursedirs = + dist + build + .tox +testpaths = tests + +[aliases] +dists = bdist_wheel + +[bdist_wheel] +# Use this option if your package is pure-python +universal = 1 + +[build_sphinx] +source_dir = doc +build_dir = build/sphinx + +[devpi:upload] +# Options for the devpi: PyPI server and packaging tool +# VCS export must be deactivated since we are using setuptools-scm +no-vcs = 1 +formats = bdist_wheel + +[flake8] +# Some sane defaults for the code style checker flake8 +exclude = + .tox + build + dist + .eggs diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..79282b3 --- /dev/null +++ b/setup.py @@ -0,0 +1,63 @@ +import os +import sys +from pathlib import Path + +from pkg_resources import VersionConflict, require +from setuptools import find_packages, setup + +with open("README.md") as f: + long_description = f.read() + +with open("requirements.txt") as f: + requirements = f.read().splitlines() + +try: + require("setuptools>=38.3") +except VersionConflict: + print("Error: version of setuptools is too old (<38.3)!") + sys.exit(1) + + +ROOT_DIR = Path(__file__).parent.resolve() +# Creating the version file + +with open("version.txt") as f: + version = f.read() + +version = version.strip() +sha = "Unknown" + +if os.getenv("BUILD_VERSION"): + version = os.getenv("BUILD_VERSION") +elif sha != "Unknown": + version += "+" + sha[:7] +print("-- Building version " + version) + +version_path = ROOT_DIR / "enhancer" / "version.py" + +with open(version_path, "w") as f: + f.write("__version__ = '{}'\n".format(version)) + +if __name__ == "__main__": + setup( + name="enhancer", + namespace_packages=["enhancer"], + version=version, + packages=find_packages(), + install_requires=requirements, + description="Deep learning toolkit for speech enhancement", + long_description=long_description, + long_description_content_type="text/markdown", + author="Shahul Es", + author_email="shahules786@gmail.com", + url="", + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + ], + ) diff --git a/version.txt b/version.txt new file mode 100644 index 0000000..8acdd82 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +0.0.1